# xref: /aosp_15_r20/external/pytorch/benchmarks/serialization/simple_measurement.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import os
import tempfile

from pyarkbench import Benchmark, default_args, Timer

import torch


# Value forwarded to torch.save(..., _use_new_zipfile_serialization=...).
# True selects the zipfile-based format; False selects the legacy format.
use_new: bool = True


class Basic(Benchmark):
    """Time torch.save / torch.load round trips for two tensor-list workloads.

    Each pass serializes one list of a few large tensors and one list of
    many small tensors, then loads each file back. Every step is timed with
    pyarkbench's ``Timer`` and reported via its ``ms_duration``.
    """

    def benchmark(self):
        """Run one measurement pass and return the step durations.

        Returns:
            dict mapping "Big Tensors Save", "Big Tensors Load",
            "Small Tensors Save" and "Small Tensors Load" to the
            ``ms_duration`` of the corresponding ``Timer``.
        """
        # Write into a private temporary directory that is deleted afterwards,
        # so repeated passes do not leave files in (or collide within) the CWD.
        with tempfile.TemporaryDirectory() as tmp_dir:
            big_save, big_load = self._time_round_trip(
                os.path.join(tmp_dir, "big_tensor.zip"),
                [torch.ones(200, 200) for _ in range(30)],
            )
            small_save, small_load = self._time_round_trip(
                os.path.join(tmp_dir, "small_tensor.zip"),
                [torch.ones(10, 10) for _ in range(200)],
            )

        return {
            "Big Tensors Save": big_save,
            "Big Tensors Load": big_load,
            "Small Tensors Save": small_save,
            "Small Tensors Load": small_load,
        }

    @staticmethod
    def _time_round_trip(path, tensors):
        """Save *tensors* to *path*, load them back; return (save, load) ms_duration."""
        with Timer() as save_timer:
            torch.save(tensors, path, _use_new_zipfile_serialization=use_new)
        with Timer() as load_timer:
            # Only the load time matters; the loaded object is discarded.
            torch.load(path)
        return save_timer.ms_duration, load_timer.ms_duration


if __name__ == "__main__":
    # Build the benchmark from pyarkbench's default benchmark arguments.
    bench = Basic(*default_args.bench())
    # Record which serialization format this run exercises.
    print("Use zipfile serialization:", use_new)
    results = bench.run()
    bench.print_stats(results, stats=["mean", "median"])