xref: /aosp_15_r20/external/pytorch/benchmarks/overrides_benchmark/pyspybench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2
3from common import SubTensor, SubWithTorchFunction, WithTorchFunction  # noqa: F401
4
5import torch
6
7
# Plain tensors are created through the ``torch.tensor`` factory so that
# "Tensor" can be selected on the command line like the classes imported
# from ``common``.
Tensor = torch.tensor

# Default number of ``torch.add`` calls performed by the benchmark loop.
NUM_REPEATS = 1000000

# The only names the command line may select. Restricting the lookup keeps
# ``globals()[...]`` from resolving unrelated module globals (e.g. ``torch``)
# and turns a typo into a clear argparse error instead of a ``KeyError``.
_TENSOR_CLASS_NAMES = (
    "Tensor",
    "SubTensor",
    "WithTorchFunction",
    "SubWithTorchFunction",
)


def _parse_args(argv=None):
    """Parse the benchmark's command line.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``tensor_class`` (one of
        ``_TENSOR_CLASS_NAMES``) and ``nreps`` (int).

    Raises:
        SystemExit: If the arguments are invalid (argparse exits with
            status 2), including an unknown tensor class name.
    """
    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class",
        metavar="TensorClass",
        type=str,
        choices=_TENSOR_CLASS_NAMES,
        help="The class to benchmark.",
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Call ``torch.add`` repeatedly on two instances of the chosen class.

    The loop is intentionally bare so that an external sampling profiler
    sees little besides the ``torch.add`` dispatch itself.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: If the command line is invalid.
    """
    args = _parse_args(argv)

    # Safe lookup: argparse has already limited the name to the known classes.
    tensor_class = globals()[args.tensor_class]

    t1 = tensor_class([1.0])
    t2 = tensor_class([2.0])

    # A non-positive repeat count simply yields an empty loop.
    for _ in range(args.nreps):
        torch.add(t1, t2)


if __name__ == "__main__":
    main()
32