"""Micro-benchmark for ``torch.add`` across tensor-like types.

For each type, time many calls to ``torch.add`` and report the per-call minimum
time and standard deviation (the command-line description calls these the
``__torch_function__`` benchmarks).
"""

import argparse
import time

from common import SubTensor, SubWithTorchFunction, WithTorchFunction

import torch


# Number of ``torch.add`` calls timed together as one measurement.
NUM_REPEATS = 1000
# Number of measurements taken per tensor type.
NUM_REPEAT_OF_REPEATS = 1000


def _positive_int(value):
    """argparse ``type=`` callable: parse ``value`` as an int that must be >= 1.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is below 1 (argparse turns
            this into a usage error).
    """
    ivalue = int(value)
    if ivalue < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return ivalue


def bench(t1, t2, nreps=None, nrepreps=None):
    """Time ``torch.add(t1, t2)`` and return per-call timing statistics.

    Args:
        t1, t2: Operands passed to ``torch.add``.
        nreps: Calls per measurement; defaults to the module-level
            ``NUM_REPEATS``.
        nrepreps: Number of measurements; defaults to the module-level
            ``NUM_REPEAT_OF_REPEATS``.

    Returns:
        ``(min_time, std_time)`` in seconds for a single call. These are the
        minimum and standard deviation over all measurements, each divided by
        ``nreps``. ``std_time`` is NaN when only one measurement is taken,
        because ``torch.std`` is unbiased by default.
    """
    if nreps is None:
        nreps = NUM_REPEATS
    if nrepreps is None:
        nrepreps = NUM_REPEAT_OF_REPEATS

    bench_times = []
    for _ in range(nrepreps):
        # perf_counter is monotonic and high-resolution; time.time() can jump.
        time_start = time.perf_counter()
        for _ in range(nreps):
            torch.add(t1, t2)
        bench_times.append(time.perf_counter() - time_start)

    # float64 keeps full precision of the measured durations.
    times = torch.tensor(bench_times, dtype=torch.float64)
    # Each measurement covers ``nreps`` calls, so normalize to one call
    # (previously a hard-coded 1000, which was wrong for any other --nreps).
    bench_time = float(torch.min(times)) / nreps
    bench_std = float(torch.std(times)) / nreps

    return bench_time, bench_std


def main():
    """Parse command-line options and print timings for each tensor type."""
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    parser = argparse.ArgumentParser(
        description="Run the __torch_function__ benchmarks."
    )
    parser.add_argument(
        "--nreps",
        "-n",
        type=_positive_int,
        default=NUM_REPEATS,
        help="The number of repeats for one measurement.",
    )
    parser.add_argument(
        "--nrepreps",
        "-m",
        type=_positive_int,
        default=NUM_REPEAT_OF_REPEATS,
        help="The number of measurements.",
    )
    args = parser.parse_args()

    # Keep the module globals in sync with the CLI for any code that reads them.
    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    # NOTE(review): the last three come from the local ``common`` module;
    # presumably tensor subclasses/wrappers with and without a
    # ``__torch_function__`` override — confirm there.
    tensor_types = torch.tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

    for t in tensor_types:
        tensor_1 = t([1.0])
        tensor_2 = t([2.0])

        bench_min, bench_std = bench(
            tensor_1, tensor_2, NUM_REPEATS, NUM_REPEAT_OF_REPEATS
        )
        print(
            f"Type {t.__name__} had a minimum time of {10**6 * bench_min} us"
            f" and a standard deviation of {(10**6) * bench_std} us."
        )


if __name__ == "__main__":
    main()