#!/usr/bin/env python3
#
# Computes difference between measurements produced by ./benchmark.py.
#

import argparse
import json

import numpy as np


# Percentiles of the per-iteration timings shown for every round; used for
# both the column header and the measurement rows so they cannot drift apart.
_PERCENTILES = (75, 95)


def load(path):
    """Read the file at ``path`` and return its parsed JSON content."""
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _format_throughput(batch_size, seconds):
    """Return examples/sec as a 7-wide field, or ``n/a`` if it is undefined."""
    # A zero (or NaN) duration has no meaningful throughput; dividing would
    # yield inf/nan, which int() cannot convert.
    if seconds > 0:
        return f"{int(batch_size / seconds):7d}"
    return f"{'n/a':>7s}"


def _format_delta(baseline, test):
    """Return the relative improvement in percent as an 8-wide field."""
    if baseline == 0:
        return f"{'n/a':>8s}"
    # We're measuring time, so lower is better (hence the negation)
    return f"{-100 * ((test - baseline) / baseline):+8.1f}"


def _print_metadata(ja, jb):
    """Print every top-level key except ``benchmark_results`` side by side."""
    keys = (set(ja) | set(jb)) - {"benchmark_results"}
    print(f"{'':20s} {'baseline':>20s} {'test':>20s}")
    print(f"{'':20s} {'-' * 20:>20s} {'-' * 20:>20s}")
    for key in sorted(keys):
        # A key present in only one of the two files is shown as "-".
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
    print()


def _print_benchmark(ra, rb):
    """Print the per-round comparison for one matching pair of results."""
    model = ra["model"]
    batch_size = int(ra["batch_size"])
    print(f"Benchmark: {model} with batch size {batch_size}")

    # Print header
    print()
    print(f"{'':>10s}", end="")
    for _ in _PERCENTILES:
        print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="")
    print()

    # Print measurements
    for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
        # Ignore round without ddp
        if i == 0:
            continue
        # Sanity check: ignore if number of ranks is not equal
        if len(xa["ranks"]) != len(xb["ranks"]):
            continue

        ma = xa["measurements"]
        mb = xb["measurements"]
        # A percentile of an empty sample is undefined; skip such rounds.
        if not ma or not mb:
            continue

        ngpus = len(xa["ranks"])
        print(f"{ngpus:>4d} GPUs:", end="")
        for p in _PERCENTILES:
            # np.percentile does not require sorted input.
            va = np.percentile(ma, p)
            vb = np.percentile(mb, p)
            throughput = _format_throughput(batch_size, vb)
            delta = _format_delta(va, vb)
            print(f" p{p:02d}: {vb:8.3f}s {throughput}/s {delta}%", end="")
        print()
    print()


def main(argv=None):
    """Print a comparison of two benchmark result files.

    Args:
        argv: Optional list with exactly two file paths (baseline first,
            test second). When ``None``, the arguments are taken from
            ``sys.argv`` by argparse.
    """
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
    # nargs=2 makes argparse itself reject any other number of files, so no
    # additional length check is needed.
    parser.add_argument("file", nargs=2)
    args = parser.parse_args(argv)

    ja = load(args.file[0])
    jb = load(args.file[1])

    _print_metadata(ja, jb)

    # Results are paired by position; pairs describing different benchmarks
    # are skipped rather than compared.
    for ra, rb in zip(ja["benchmark_results"], jb["benchmark_results"]):
        if ra["model"] != rb["model"] or ra["batch_size"] != rb["batch_size"]:
            continue
        _print_benchmark(ra, rb)


if __name__ == "__main__":
    main()