xref: /aosp_15_r20/external/pytorch/benchmarks/overrides_benchmark/bench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport time
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom common import SubTensor, SubWithTorchFunction, WithTorchFunction
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Default number of torch.add calls timed within a single measurement.
# main() overwrites this from the --nreps/-n command-line option.
NUM_REPEATS = 1_000
# Default number of measurements taken for each operand type.
# main() overwrites this from the --nrepreps/-m command-line option.
NUM_REPEAT_OF_REPEATS = 1_000
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
def bench(t1, t2, nreps=None, nrepreps=None):
    """Time ``torch.add(t1, t2)`` and return per-call timing statistics.

    ``torch.add`` is called ``nreps`` times per measurement, and ``nrepreps``
    measurements are taken. Every measurement is divided by ``nreps``, so the
    returned figures are per call regardless of the configured counts.

    Args:
        t1, t2: Operands forwarded to ``torch.add``. These may be tensors,
            tensor subclasses, or objects implementing ``__torch_function__``.
        nreps: Calls per measurement. Defaults to the module-level
            ``NUM_REPEATS``, which ``main`` may override from the CLI.
        nrepreps: Number of measurements. Defaults to the module-level
            ``NUM_REPEAT_OF_REPEATS``.

    Returns:
        ``(min_time, std_time)`` in seconds per call. ``min_time`` is the
        fastest measurement and ``std_time`` is the sample standard deviation
        across measurements. ``std_time`` is NaN when only one measurement is
        taken.

    Raises:
        ValueError: If ``nreps`` or ``nrepreps`` is less than 1.
    """
    if nreps is None:
        nreps = NUM_REPEATS
    if nrepreps is None:
        nrepreps = NUM_REPEAT_OF_REPEATS
    if nreps < 1 or nrepreps < 1:
        raise ValueError(
            f"nreps and nrepreps must be >= 1, got nreps={nreps}, "
            f"nrepreps={nrepreps}"
        )

    bench_times = []
    for _ in range(nrepreps):
        # perf_counter is monotonic and high-resolution, which makes it the
        # appropriate clock for short intervals. time.time() is neither.
        time_start = time.perf_counter()
        for _ in range(nreps):
            torch.add(t1, t2)
        bench_times.append(time.perf_counter() - time_start)

    times = torch.tensor(bench_times, dtype=torch.float64)
    # Normalize by the real inner-loop count. The previous hard-coded 1000
    # was wrong whenever --nreps differed from the default.
    bench_time = float(torch.min(times)) / nreps
    bench_std = float(torch.std(times)) / nreps

    return bench_time, bench_std
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
def main():
    """Parse CLI options and benchmark ``torch.add`` for each operand type.

    ``--nreps``/``-n`` and ``--nrepreps``/``-m`` overwrite the module-level
    ``NUM_REPEATS`` and ``NUM_REPEAT_OF_REPEATS`` used by ``bench``. Both
    options must be positive integers. For each of ``torch.tensor``,
    ``SubTensor``, ``WithTorchFunction`` and ``SubWithTorchFunction``, two
    one-element operands are constructed. The minimum time and standard
    deviation returned by ``bench`` are printed in microseconds.
    """
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    def positive_int(value):
        """Argparse ``type=`` converter that accepts only integers >= 1."""
        # Zero or negative counts would time no calls at all, or leave no
        # measurements to reduce. Reject them with a usage error up front.
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError(
                f"expected a positive integer, got {value!r}"
            )
        return number

    parser = argparse.ArgumentParser(
        description="Run the __torch_function__ benchmarks."
    )
    parser.add_argument(
        "--nreps",
        "-n",
        type=positive_int,
        default=NUM_REPEATS,
        help="The number of repeats for one measurement.",
    )
    parser.add_argument(
        "--nrepreps",
        "-m",
        type=positive_int,
        default=NUM_REPEAT_OF_REPEATS,
        help="The number of measurements.",
    )
    args = parser.parse_args()

    # bench() reads these module-level globals, so publish the CLI values there.
    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    types = torch.tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

    for t in types:
        tensor_1 = t([1.0])
        tensor_2 = t([2.0])

        bench_min, bench_std = bench(tensor_1, tensor_2)
        # bench() reports seconds; scale by 10**6 to print microseconds.
        print(
            f"Type {t.__name__} had a minimum time of {10**6 * bench_min} us"
            f" and a standard deviation of {(10**6) * bench_std} us."
        )
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run the benchmarks only when executed as a script, not when imported.
    main()
68