"""Benchmark the effect of FX common-subexpression elimination (CSE) on CUDA time.

Each test function is traced with ``make_fx``, rewritten with ``fx_graph_cse``,
and both graph modules are profiled on the same CUDA input. One comma-separated
line is printed per function:
``name, avg_cuda_time_original, avg_cuda_time_cse, num_node_decrease, num_nodes``.
"""
import torch
import torch.fx as fx
from functorch import make_fx
from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity


def _device_time_total(event):
    """Return the total device (CUDA) time recorded for one averaged profiler event.

    Newer PyTorch releases expose ``device_time_total`` and deprecate
    ``cuda_time_total``; older releases only have the latter, so fall back to it.
    """
    total = getattr(event, "device_time_total", None)
    if total is None:
        total = event.cuda_time_total
    return total


def profile_it(f, inp, warmup=5, iters=5):
    """Return the average profiled CUDA time of one call to ``f(inp)``.

    Args:
        f: callable to benchmark.
        inp: argument passed to ``f`` on every call.
        warmup: number of untimed calls made before profiling starts.
        iters: number of profiled calls the total CUDA time is averaged over.

    Returns:
        The sum of the CUDA time of every event in ``prof.key_averages()``,
        divided by ``iters``.

    Raises:
        ValueError: if ``iters`` is not positive (the average would divide by zero).
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    for _ in range(warmup):
        f(inp)
    # CUDA kernels launch asynchronously: wait for the warm-up work to finish so
    # it cannot execute inside (and be counted by) the profiled window below.
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(iters):
            f(inp)

    total = sum(_device_time_total(event) for event in prof.key_averages())
    return total / iters


def profile_function(name, f, inp):
    """Trace ``f`` on ``inp``, apply FX CSE, and print one comparison row.

    Printed columns: ``name``, average CUDA time of the traced graph, average
    CUDA time of the CSE'd graph, the difference in node count between the two
    graphs, and the node count of the traced graph.
    """
    fx_g = make_fx(f)(inp)

    cse_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, cse_graph)
    # The TorchScript'd versions are deliberately not benchmarked:
    # ``torch.jit.script`` already performs some CSE of its own, which would
    # mask the effect measured here.
    avg_cuda_time_f = profile_it(fx_g, inp)
    avg_cuda_time_g = profile_it(new_g, inp)
    num_nodes = len(fx_g.graph.nodes)
    num_node_decrease = num_nodes - len(new_g.graph.nodes)

    print(
        f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {num_nodes}"
    )


# Seeded CUDA generator so every run benchmarks the same input values.
g_gpu = torch.Generator(device="cuda")
g_gpu.manual_seed(2147483647)
inp = torch.randn(2**20, device="cuda", generator=g_gpu)


def f1(x):
    """Chain two ``cos`` ops; the second consumes the first, so nothing repeats."""
    return x.cos().cos()


profile_function("f1", f1, inp)


def fsum(x):
    """Compute ``x.sum()`` four times (identical ops) and add the results."""
    a = x.sum()
    b = x.sum()
    c = x.sum()
    d = x.sum()
    return a + b + c + d


profile_function("fsum", fsum, inp)


def fconcat(x):
    """Build ``torch.cat((x, x))`` twice (identical ops) and add the results."""
    a = torch.cat((x, x))
    b = torch.cat((x, x))
    return a + b


profile_function("fconcat", fconcat, inp)


def fsum2(x):
    """Accumulate 31 identical ``x.sum()`` reductions."""
    a = x.sum()
    for _ in range(30):
        a = a + x.sum()
    return a


profile_function("fsum2", fsum2, inp)


def fsummulti(x):
    """Run 3 rounds of add-then-multiply, each using a fresh ``x.sum()``."""
    a = 0
    for _ in range(3):
        a = a + x.sum()
        a = a * x.sum()
    return a
def fsummulti2(x):
    """Apply thirty add/multiply rounds, each drawing on two fresh ``x.sum()`` calls."""
    acc = 0
    for _round in range(30):
        # Python evaluates the left operand first, so the traced op order is
        # still sum, add, sum, mul — the same as two separate statements.
        acc = (acc + x.sum()) * x.sum()
    return acc


def fcos(x):
    """Add up three identical ``x.cos()`` results."""
    total = 0
    for _step in range(3):
        total = total + x.cos()
    return total


def fcos2(x):
    """Add up thirty identical ``x.cos()`` results."""
    total = 0
    for _step in range(30):
        total = total + x.cos()
    return total


# Benchmarks run in the same order as before; definitions above have no side effects.
profile_function("fsummulti", fsummulti, inp)
profile_function("fsummulti2", fsummulti2, inp)
profile_function("fcos", fcos, inp)
profile_function("fcos2", fcos2, inp)