#!/usr/bin/env python3
#
# Computes difference between measurements produced by ./benchmark.py.
#

import argparse
import json

import numpy as np

# Percentiles reported for every GPU count; shared by the column header and
# the data rows so the two can never drift apart.
_PERCENTILES = (75, 95)


def load(path):
    """Read and return the JSON document stored at *path*."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _print_metadata(ja, jb):
    """Print every top-level key except the results, baseline vs test."""
    keys = (set(ja) | set(jb)) - {"benchmark_results"}
    print(f"{'':20s} {'baseline':>20s} {'test':>20s}")
    print(f"{'':20s} {'-' * 20:>20s} {'-' * 20:>20s}")
    for key in sorted(keys):
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
    print()


def _print_benchmark(ra, rb):
    """Print the percentile comparison table for one matched benchmark pair.

    ``ra`` is the baseline entry and ``rb`` the test entry. Both must share the
    same model and batch size. Each ``"result"`` item is expected to carry
    ``"ranks"`` and ``"measurements"`` (seconds per iteration) lists.
    """
    model = ra["model"]
    batch_size = int(ra["batch_size"])
    print(f"Benchmark: {model} with batch size {batch_size}")

    # Header: one column group per reported percentile.
    print()
    print(f"{'':>10s}", end="")
    for _ in _PERCENTILES:
        print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="")
    print()

    # The first round runs without DDP, so skip it in both files.
    for xa, xb in zip(ra["result"][1:], rb["result"][1:]):
        # Sanity check: only compare rounds that used the same number of ranks.
        if len(xa["ranks"]) != len(xb["ranks"]):
            continue

        ngpus = len(xa["ranks"])
        print(f"{ngpus:>4d} GPUs:", end="")
        for p in _PERCENTILES:
            va = np.percentile(xa["measurements"], p)
            vb = np.percentile(xb["measurements"], p)
            # We're measuring time, so lower is better (hence the negation):
            # a positive delta means the test run is faster than the baseline.
            delta = -100 * ((vb - va) / va)
            print(
                f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
                end="",
            )
        print()
    print()


def main(argv=None):
    """Diff two benchmark JSON files (baseline first, test second).

    ``argv`` defaults to ``sys.argv[1:]`` via argparse. Benchmark entries are
    matched by ``(model, batch_size)``, not by position. Entries present in only
    one file are skipped, and output follows the baseline file's order.
    """
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
    # nargs=2 makes argparse itself reject any other number of files.
    parser.add_argument("file", nargs=2)
    args = parser.parse_args(argv)

    ja = load(args.file[0])
    jb = load(args.file[1])

    _print_metadata(ja, jb)

    # Index the test results so reordered or missing entries in either file
    # do not hide comparisons. The first occurrence of a key wins.
    tests_by_key = {}
    for rb in jb["benchmark_results"]:
        tests_by_key.setdefault((rb["model"], rb["batch_size"]), rb)

    for ra in ja["benchmark_results"]:
        rb = tests_by_key.get((ra["model"], ra["batch_size"]))
        if rb is None:
            continue
        _print_benchmark(ra, rb)


if __name__ == "__main__":
    main()