xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/microbenchmarks/overheads.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport time
2*da0073e9SAndroid Build Coastguard Workerimport timeit
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport numpy as np
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
def add1(x):
    """Return ``x + 1``; the trivial workload whose call overhead is benchmarked."""
    incremented = x + 1
    return incremented
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerdef bench(name, fn, requires_grad):
14*da0073e9SAndroid Build Coastguard Worker    torch._dynamo.reset()
15*da0073e9SAndroid Build Coastguard Worker    x = torch.randn(1, requires_grad=requires_grad)
16*da0073e9SAndroid Build Coastguard Worker    start = time.perf_counter()
17*da0073e9SAndroid Build Coastguard Worker    for _ in range(3):
18*da0073e9SAndroid Build Coastguard Worker        fn(x)
19*da0073e9SAndroid Build Coastguard Worker    end = time.perf_counter()
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker    results = timeit.repeat(lambda: fn(x), number=1000, repeat=1000)
22*da0073e9SAndroid Build Coastguard Worker    print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)")
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
def main():
    """Benchmark eager vs. compiled ``add1`` across three autograd settings."""

    def run_pair(flag):
        # Eager first, then a freshly compiled wrapper, matching the report layout.
        bench("eager   ", add1, flag)
        bench("compiled", torch.compile(add1), flag)

    for header, flag in (("requires_grad=False", False), ("requires_grad=True", True)):
        print(header)
        run_pair(flag)
        print()

    print("inference_mode()")
    with torch.inference_mode():
        run_pair(False)
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Only run the (long) benchmark suite when executed as a script, not on import.
    main()
42