from pyarkbench import Benchmark, default_args, Timer

import torch


# Forwarded to ``torch.save`` as ``_use_new_zipfile_serialization``.
use_new = True


class Basic(Benchmark):
    """Benchmark that times ``torch.save`` / ``torch.load`` on two workloads."""

    def benchmark(self):
        """Save and reload a list of tensors twice, timing every step.

        The first pass uses 30 tensors of shape (200, 200); the second uses
        200 tensors of shape (10, 10). Files are written to the current
        working directory and are not removed afterwards.

        Returns:
            A dict mapping a human-readable label to the ``ms_duration``
            reported by the ``Timer`` that wrapped the matching call.
        """
        # Pass 1: a few large tensors.
        tensors = [torch.ones(200, 200) for _ in range(30)]
        with Timer() as save_big:
            torch.save(
                tensors,
                "big_tensor.zip",
                _use_new_zipfile_serialization=use_new,
            )

        with Timer() as load_big:
            loaded = torch.load("big_tensor.zip")

        # Pass 2: many small tensors. The same local names are rebound on
        # purpose so object lifetimes match between the two passes.
        tensors = [torch.ones(10, 10) for _ in range(200)]
        with Timer() as save_small:
            torch.save(
                tensors,
                "small_tensor.zip",
                _use_new_zipfile_serialization=use_new,
            )

        with Timer() as load_small:
            loaded = torch.load("small_tensor.zip")

        labelled_timers = (
            ("Big Tensors Save", save_big),
            ("Big Tensors Load", load_big),
            ("Small Tensors Save", save_small),
            ("Small Tensors Load", load_small),
        )
        return {label: timer.ms_duration for label, timer in labelled_timers}


if __name__ == "__main__":
    # Build the benchmark from pyarkbench's default arguments, report which
    # serialization mode is active, then run and summarise the timings.
    bench = Basic(*default_args.bench())
    print("Use zipfile serialization:", use_new)
    results = bench.run()
    bench.print_stats(results, stats=["mean", "median"])