"""Microbenchmark comparing per-call latency of eager vs ``torch.compile``.

Each scenario runs a trivial ``x + 1`` on a one-element tensor and prints the
median per-call time plus the wall time of the warmup calls. Because the op is
so small, the measured time is presumably dominated by framework overhead
rather than arithmetic.
"""
import time
import timeit

import numpy as np

import torch
import torch._dynamo  # explicit: not every torch release exposes torch._dynamo lazily


def add1(x):
    """Return ``x + 1``; deliberately trivial so timings reflect call overhead."""
    return x + 1


def bench(name, fn, requires_grad, *, warmup=3, number=1000, repeat=1000):
    """Time ``fn`` on a fresh one-element tensor and print a one-line summary.

    Args:
        name: Label printed at the start of the result line.
        fn: Callable taking a single tensor (a plain function or a
            ``torch.compile`` wrapper).
        requires_grad: Forwarded to ``torch.randn`` when building the input.
        warmup: Calls made before measurement, timed together. For a
            ``torch.compile`` wrapper the first call is expected to trigger
            compilation, so this time includes compile cost.
        number: Calls per ``timeit`` sample.
        repeat: Number of ``timeit`` samples; the median sample is reported.

    Prints ``"<name> <median per-call time>us (warmup=<seconds>s)"``.

    Raises:
        ValueError: If ``number`` or ``repeat`` is less than 1.
    """
    if number < 1 or repeat < 1:
        raise ValueError("number and repeat must be positive")

    # Clear Dynamo state left by a previous bench() call so this run's warmup
    # pays its own compilation cost instead of reusing cached artifacts.
    torch._dynamo.reset()
    x = torch.randn(1, requires_grad=requires_grad)

    start = time.perf_counter()
    for _ in range(warmup):
        fn(x)
    end = time.perf_counter()

    results = timeit.repeat(lambda: fn(x), number=number, repeat=repeat)
    # Each sample is the total seconds for `number` calls; convert to
    # microseconds per call (previously a hard-coded *1000 tied to number=1000).
    per_call_us = np.median(results) / number * 1e6
    print(f"{name} {per_call_us:.1f}us (warmup={end-start:.1f}s)")


def main():
    """Benchmark eager vs compiled ``add1`` under three autograd settings."""
    print("requires_grad=False")
    bench("eager ", add1, False)
    bench("compiled", torch.compile(add1), False)
    print()
    print("requires_grad=True")
    bench("eager ", add1, True)
    bench("compiled", torch.compile(add1), True)
    print()
    print("inference_mode()")
    with torch.inference_mode():
        bench("eager ", add1, False)
        bench("compiled", torch.compile(add1), False)


if __name__ == "__main__":
    main()