import torch import torch.fx as fx from functorch import make_fx from torch._functorch.compile_utils import fx_graph_cse from torch.profiler import profile, ProfilerActivity def profile_it(f, inp): for _ in range(5): f(inp) itr = 5 with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: for _ in range(itr): f(inp) timing = prof.key_averages() cuda_time_total = 0 for e in timing: cuda_time_total = cuda_time_total + e.cuda_time_total return cuda_time_total / itr def profile_function(name, f, inp): fx_g = make_fx(f)(inp) new_g = fx_graph_cse(fx_g.graph) new_g = fx.GraphModule(fx_g, new_g) # do not benchmark against the scripted version because script already does some CSE # script_f = torch.jit.script(fx_g) # script_g = torch.jit.script(new_g) # avg_cuda_time_f = profile_it(script_f, inp) # avg_cuda_time_g = profile_it(script_g, inp) avg_cuda_time_f = profile_it(fx_g, inp) avg_cuda_time_g = profile_it(new_g, inp) num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes) print( f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}" ) g_gpu = torch.Generator(device="cuda") g_gpu.manual_seed(2147483647) inp = torch.randn(2**20, device="cuda", generator=g_gpu) def f1(x): return x.cos().cos() profile_function("f1", f1, inp) def fsum(x): a = x.sum() b = x.sum() c = x.sum() d = x.sum() return a + b + c + d profile_function("fsum", fsum, inp) def fconcat(x): a = torch.cat((x, x)) b = torch.cat((x, x)) return a + b profile_function("fconcat", fconcat, inp) def fsum2(x): a = x.sum() for _ in range(30): a = a + x.sum() return a profile_function("fsum2", fsum2, inp) def fsummulti(x): a = 0 for _ in range(3): a = a + x.sum() a = a * x.sum() return a profile_function("fsummulti", fsummulti, inp) def fsummulti2(x): a = 0 for _ in range(30): a = a + x.sum() a = a * x.sum() return a profile_function("fsummulti2", fsummulti2, inp) def fcos(x): a = 0 for _ in range(3): a = a + x.cos() return a profile_function("fcos", fcos, inp) def fcos2(x): a = 0 for _ in range(30): a = a + x.cos() return a profile_function("fcos2", fcos2, inp)