"""Benchmark ``torch.add`` on instances of a chosen tensor-like class.

Usage: ``python <this script> TensorClass [--nreps N]``

``TensorClass`` must be one of the keys of ``_TENSOR_CLASSES``. Two
one-element instances are built from it, and ``torch.add`` is called on them
``N`` times (default ``NUM_REPEATS``). The script does no timing itself;
presumably it is run under an external timer — confirm with the harness.
"""
import argparse

from common import SubTensor, SubWithTorchFunction, WithTorchFunction

import torch


Tensor = torch.tensor

NUM_REPEATS = 1000000

# Explicit registry of benchmarkable classes. It is used instead of a
# globals() lookup so that only these names are accepted: any other
# module-level name (e.g. "argparse") or a typo is rejected by argparse
# with a usage error rather than crashing later.
_TENSOR_CLASSES = {
    "Tensor": Tensor,
    "SubTensor": SubTensor,
    "WithTorchFunction": WithTorchFunction,
    "SubWithTorchFunction": SubWithTorchFunction,
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class",
        metavar="TensorClass",
        type=str,
        choices=list(_TENSOR_CLASSES),
        help="The class to benchmark.",
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args()

    TensorClass = _TENSOR_CLASSES[args.tensor_class]
    NUM_REPEATS = args.nreps

    t1 = TensorClass([1.0])
    t2 = TensorClass([2.0])

    # The measured loop is intentionally left at module scope (not moved into
    # a function). This keeps the per-iteration name-lookup overhead identical
    # to earlier runs, so results remain comparable across classes and versions.
    for _ in range(NUM_REPEATS):
        torch.add(t1, t2)