"""Benchmark FX common-subexpression elimination (CSE) on a CUDA device.

Each benchmark function is traced with ``make_fx``, a second graph is
produced by ``fx_graph_cse``, and both graph modules are profiled on the
same input.  One comma-separated row is printed per benchmark:

    name, avg device time before CSE, avg device time after CSE,
    number of nodes removed, number of nodes before CSE
"""
import torch
import torch.fx as fx
from functorch import make_fx
from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity


def _device_time_total(event):
    """Return the accumulated device time recorded on a profiler event.

    NOTE(review): newer PyTorch releases appear to expose this value as
    ``device_time_total`` and deprecate ``cuda_time_total`` -- confirm
    against the targeted torch version.  Preferring the new attribute and
    falling back to the old one keeps the script usable on both.
    """
    value = getattr(event, "device_time_total", None)
    if value is None:
        value = event.cuda_time_total
    return value


def profile_it(f, inp, warmup=5, itr=5):
    """Return the average profiled CUDA time of one ``f(inp)`` call.

    Args:
        f: Callable to benchmark; it is invoked as ``f(inp)``.
        inp: Argument passed to ``f`` on every call.
        warmup: Number of unprofiled calls made before measuring.
        itr: Number of profiled calls the total is averaged over.

    Returns:
        The summed device time of all aggregated profiler events divided
        by ``itr``, in whatever unit the profiler reports.

    Raises:
        ValueError: If ``itr`` is not positive.
    """
    if itr <= 0:
        raise ValueError(f"itr must be positive, got {itr}")

    # Warm-up calls run outside the profiler so they are not measured.
    for _ in range(warmup):
        f(inp)

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(itr):
            f(inp)

    total = sum(_device_time_total(event) for event in prof.key_averages())
    return total / itr


def profile_function(name, f, inp):
    """Profile ``f`` before and after CSE and print one result row.

    Args:
        name: Label printed in the first column of the row.
        f: Function to trace with ``make_fx``.
        inp: Input used both for tracing and for profiling.
    """
    fx_g = make_fx(f)(inp)

    new_g = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_g)

    # Do not benchmark against the torch.jit.script-ed versions of the two
    # graph modules, because script already does some CSE.
    avg_cuda_time_f = profile_it(fx_g, inp)
    avg_cuda_time_g = profile_it(new_g, inp)

    num_nodes = len(fx_g.graph.nodes)
    num_node_decrease = num_nodes - len(new_g.graph.nodes)

    print(
        f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {num_nodes}"
    )


def f1(x):
    """Apply ``cos`` twice in a chain."""
    return x.cos().cos()


def fsum(x):
    """Add together four identical ``x.sum()`` results."""
    a = x.sum()
    b = x.sum()
    c = x.sum()
    d = x.sum()
    return a + b + c + d


def fconcat(x):
    """Add together two identical ``torch.cat((x, x))`` results."""
    a = torch.cat((x, x))
    b = torch.cat((x, x))
    return a + b


def fsum2(x):
    """Accumulate ``x.sum()`` 31 times (one initial value plus 30 additions)."""
    a = x.sum()
    for _ in range(30):
        a = a + x.sum()
    return a


def fsummulti(x):
    """Alternate adding and multiplying by ``x.sum()`` for 3 rounds."""
    a = 0
    for _ in range(3):
        a = a + x.sum()
        a = a * x.sum()
    return a


def fsummulti2(x):
    """Alternate adding and multiplying by ``x.sum()`` for 30 rounds."""
    a = 0
    for _ in range(30):
        a = a + x.sum()
        a = a * x.sum()
    return a


def fcos(x):
    """Accumulate ``x.cos()`` 3 times."""
    a = 0
    for _ in range(3):
        a = a + x.cos()
    return a


def fcos2(x):
    """Accumulate ``x.cos()`` 30 times."""
    a = 0
    for _ in range(30):
        a = a + x.cos()
    return a


def main():
    """Build the seeded CUDA input and run every benchmark in order."""
    g_gpu = torch.Generator(device="cuda")
    g_gpu.manual_seed(2147483647)
    inp = torch.randn(2**20, device="cuda", generator=g_gpu)

    benchmarks = (
        ("f1", f1),
        ("fsum", fsum),
        ("fconcat", fconcat),
        ("fsum2", fsum2),
        ("fsummulti", fsummulti),
        ("fsummulti2", fsummulti2),
        ("fcos", fcos),
        ("fcos2", fcos2),
    )
    for name, fn in benchmarks:
        profile_function(name, fn, inp)


if __name__ == "__main__":
    main()