"""Call ``torch.add`` repeatedly on instances of a chosen tensor class.

The script only runs the loop and does no timing itself. Presumably it is run
under an external timer or profiler to compare plain tensors against the
subclasses / ``__torch_function__`` wrappers defined in ``common``.
"""

import argparse

import torch

from common import SubTensor, SubWithTorchFunction, WithTorchFunction


Tensor = torch.tensor

NUM_REPEATS = 1000000


def main(argv=None):
    """Parse command-line arguments and run ``torch.add`` in a loop.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Explicit table of selectable classes. The original looked the name up in
    # ``globals()``, which also accepted unrelated module-level names (e.g.
    # ``torch`` or ``NUM_REPEATS``) and failed with a bare KeyError on typos.
    # With ``choices=`` argparse reports invalid names as a usage error.
    tensor_classes = {
        "Tensor": Tensor,
        "SubTensor": SubTensor,
        "WithTorchFunction": WithTorchFunction,
        "SubWithTorchFunction": SubWithTorchFunction,
    }

    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class",
        metavar="TensorClass",
        type=str,
        choices=sorted(tensor_classes),
        help="The class to benchmark.",
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args(argv)

    tensor_class = tensor_classes[args.tensor_class]

    t1 = tensor_class([1.0])
    t2 = tensor_class([2.0])

    # The result is discarded on purpose: only the cost of the call matters.
    for _ in range(args.nreps):
        torch.add(t1, t2)


if __name__ == "__main__":
    main()