xref: /aosp_15_r20/external/pytorch/benchmarks/overrides_benchmark/pyspybench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom common import SubTensor, SubWithTorchFunction, WithTorchFunction  # noqa: F401
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
Tensor = torch.tensor

NUM_REPEATS = 1000000

# Class names accepted on the command line. They are resolved lazily through
# globals() inside main(): "Tensor" is the alias defined above, and the other
# three are expected to come from the `common` import at the top of the file.
# Restricting the choices prevents arbitrary module globals (e.g. "argparse",
# "torch") from being looked up and called.
_TENSOR_CLASS_NAMES = (
    "Tensor",
    "SubTensor",
    "WithTorchFunction",
    "SubWithTorchFunction",
)


def main(argv=None):
    """Repeatedly call ``torch.add`` on two one-element inputs of a chosen class.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original script behaviour.

    Raises:
        SystemExit: via argparse when the class name is not one of
            ``_TENSOR_CLASS_NAMES`` or ``--nreps`` is negative.
    """
    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class",
        metavar="TensorClass",
        type=str,
        choices=_TENSOR_CLASS_NAMES,
        help="The class to benchmark.",
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args(argv)
    if args.nreps < 0:
        # range() of a negative number would silently run zero iterations.
        parser.error(f"--nreps must be non-negative, got {args.nreps}")

    tensor_class = globals()[args.tensor_class]

    t1 = tensor_class([1.0])
    t2 = tensor_class([2.0])

    # The sums are discarded on purpose: the loop only exists to exercise
    # torch.add dispatch for the chosen class, presumably while an external
    # profiler samples the process (py-spy, per the file name) — confirm.
    for _ in range(args.nreps):
        torch.add(t1, t2)


if __name__ == "__main__":
    main()
32