#!/usr/bin/env python3
#
# Computes difference between measurements produced by ./benchmark.py.
#
# Source: /aosp_15_r20/external/pytorch/benchmarks/distributed/ddp/diff.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)

import argparse
import json
import sys

import numpy as np


def load(path):
    """Read the JSON document stored at ``path`` and return the decoded object.

    The file is opened explicitly as UTF-8 instead of relying on the
    platform's locale encoding, so the same file decodes identically on
    every system.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def main(argv=None):
    """Print a comparison of two result files produced by ./benchmark.py.

    The first file is labelled "baseline" and the second "test". The report
    has two parts, both written to stdout:

    1. A table of every top-level key other than ``"benchmark_results"``,
       showing the value from each file ("-" when a file lacks the key).
    2. For each positionally paired entry of ``"benchmark_results"`` whose
       ``model`` and ``batch_size`` agree, one row per paired ``result``
       entry (the first entry is skipped) showing, for the 75th and 95th
       percentile of the test measurements: seconds per iteration,
       examples per second, and the relative change versus the baseline.
       A positive change means the test run took less time.

    Entries that cannot be compared (mismatched model/batch size, mismatched
    rank counts, empty measurement lists, or unpaired trailing benchmarks)
    are skipped and reported on stderr so they are not dropped silently.

    Args:
        argv: Optional list of command-line arguments (the two file paths).
            When ``None``, argparse reads them from ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="PyTorch distributed benchmark diff")
    # nargs=2 makes argparse reject any other number of paths with a usage
    # error, so no separate length check is needed afterwards.
    parser.add_argument("file", nargs=2)
    args = parser.parse_args(argv)

    baseline_path, test_path = args.file
    ja = load(baseline_path)
    jb = load(test_path)

    # Metadata table: every top-level key except the results themselves.
    keys = (set(ja.keys()) | set(jb.keys())) - {"benchmark_results"}
    print(f"{'':20s} {'baseline':>20s}      {'test':>20s}")
    print(f"{'':20s} {'-' * 20:>20s}      {'-' * 20:>20s}")
    for key in sorted(keys):
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print(f"{key + ':':20s} {va:>20s}  vs  {vb:>20s}")
    print()

    # Single source of truth for the reported percentiles, so the header
    # columns and the measurement columns cannot drift apart.
    percentiles = (75, 95)

    ba = ja["benchmark_results"]
    bb = jb["benchmark_results"]
    if len(ba) != len(bb):
        # zip() below stops at the shorter list; say so instead of hiding it.
        print(
            f"warning: baseline has {len(ba)} benchmark(s) but test has "
            f"{len(bb)}; unpaired trailing entries are ignored",
            file=sys.stderr,
        )

    for index, (ra, rb) in enumerate(zip(ba, bb)):
        model = ra["model"]
        if model != rb["model"] or ra["batch_size"] != rb["batch_size"]:
            print(
                f"warning: skipping benchmark {index} ({model!r}): model or "
                "batch size differs between the two files",
                file=sys.stderr,
            )
            continue

        batch_size = int(ra["batch_size"])
        name = f"{model} with batch size {batch_size}"
        print(f"Benchmark: {name}")

        # Print header
        print()
        print(f"{'':>10s}", end="")
        for _ in percentiles:
            print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="")
        print()

        # Print measurements
        for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
            # Ignore round without ddp
            if i == 0:
                continue
            # Sanity check: ignore if number of ranks is not equal
            if len(xa["ranks"]) != len(xb["ranks"]):
                print(
                    f"warning: skipping result {i} of {name}: number of "
                    "ranks differs between the two files",
                    file=sys.stderr,
                )
                continue

            ngpus = len(xa["ranks"])
            # np.percentile does not require pre-sorted input.
            ma = xa["measurements"]
            mb = xb["measurements"]
            if not ma or not mb:
                # A percentile of nothing is undefined; skip rather than crash.
                print(
                    f"warning: skipping result {i} of {name}: no measurements",
                    file=sys.stderr,
                )
                continue

            print(f"{ngpus:>4d} GPUs:", end="")
            for p in percentiles:
                va = np.percentile(ma, p)
                vb = np.percentile(mb, p)
                # We're measuring time, so lower is better (hence the negation)
                delta = -100 * ((vb - va) / va)
                print(
                    f"  p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
                    end="",
                )
            print()
        print()


# Run the diff only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()