xref: /aosp_15_r20/external/pytorch/functorch/benchmarks/cse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2import torch.fx as fx
3from functorch import make_fx
4from torch._functorch.compile_utils import fx_graph_cse
5from torch.profiler import profile, ProfilerActivity
6
7
def profile_it(f, inp, *, warmup=5, iters=5):
    """Return the average profiled CUDA time per call of ``f(inp)``.

    ``f`` is first called ``warmup`` times outside the profiler so one-off
    costs are excluded. It is then called ``iters`` times inside a
    CUDA-only ``torch.profiler.profile`` session. The ``cuda_time_total`` of
    every aggregated profiler event is summed and divided by ``iters``.

    Args:
        f: Callable under test, invoked as ``f(inp)``.
        inp: The single argument passed to ``f``.
        warmup: Number of untimed calls made before profiling starts.
            Defaults to 5, the value previously hard-coded.
        iters: Number of profiled calls to average over. Defaults to 5,
            the value previously hard-coded.

    Returns:
        Summed ``cuda_time_total`` divided by ``iters``. This is presumably
        in microseconds, the profiler's native unit; confirm against the
        installed PyTorch version. If the profiler records no events, the
        result is 0.0.

    Raises:
        ValueError: If ``iters`` is not positive, since the average would
            otherwise be a division by zero.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    # Untimed warmup so the profiled runs do not include one-off setup work.
    for _ in range(warmup):
        f(inp)

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(iters):
            f(inp)

    total = sum(event.cuda_time_total for event in prof.key_averages())
    return total / iters
22
23
def profile_function(name, f, inp):
    """Trace ``f`` with ``make_fx``, CSE the traced graph, and print one CSV row.

    The printed row is: ``name``, the average CUDA time of the traced module,
    the average CUDA time of the CSE'd module (both from ``profile_it``), the
    number of graph nodes CSE removed, and the node count of the original
    traced graph.
    """
    traced = make_fx(f)(inp)
    deduped = fx.GraphModule(traced, fx_graph_cse(traced.graph))

    # The TorchScript-compiled variants are intentionally not benchmarked:
    # scripting already performs some CSE of its own, which would blur the
    # comparison against fx_graph_cse.
    time_traced = profile_it(traced, inp)
    time_deduped = profile_it(deduped, inp)

    nodes_before = len(traced.graph.nodes)
    nodes_removed = nodes_before - len(deduped.graph.nodes)
    print(f"{name}, {time_traced}, {time_deduped}, {nodes_removed}, {nodes_before}")
41
42
# Seeded CUDA generator so every run benchmarks the same random input.
# Generator.manual_seed returns the generator itself, so the calls chain.
g_gpu = torch.Generator(device="cuda").manual_seed(2147483647)
inp = torch.randn(1 << 20, device="cuda", generator=g_gpu)
46
47
def f1(x):
    """Apply ``cos`` twice in a chain.

    The second ``cos`` consumes the first one's output, so no op is a
    duplicate of another.
    """
    once = x.cos()
    return once.cos()
50
51
# Chained ops only: baseline with no duplicated work.
profile_function(name="f1", f=f1, inp=inp)
53
54
def fsum(x):
    """Add together four separately computed, identical ``x.sum()`` reductions.

    The four reductions are duplicates, so three of them are redundant.
    """
    s1, s2, s3, s4 = (x.sum() for _ in range(4))
    return s1 + s2 + s3 + s4
61
62
# Four identical reductions.
profile_function(name="fsum", f=fsum, inp=inp)
64
65
def fconcat(x):
    """Add together two separately computed, identical ``torch.cat((x, x))`` results."""
    first, second = (torch.cat((x, x)) for _ in range(2))
    return first + second
70
71
# Two identical concatenations.
profile_function(name="fconcat", f=fconcat, inp=inp)
73
74
def fsum2(x):
    """Accumulate ``x.sum()`` 31 times, recomputing the reduction at every step."""
    total = x.sum()
    remaining = 30
    while remaining:
        total = total + x.sum()
        remaining -= 1
    return total
80
81
# 31 identical reductions feeding a running sum.
profile_function(name="fsum2", f=fsum2, inp=inp)
83
84
def fsummulti(x):
    """Run 3 rounds of ``acc = (acc + x.sum()) * x.sum()``, starting from 0.

    Each round recomputes ``x.sum()`` twice. Evaluation order (sum, add,
    sum, mul) matches writing the add and multiply as separate statements.
    """
    acc = 0
    for _round in range(3):
        acc = (acc + x.sum()) * x.sum()
    return acc
91
92
# Duplicate reductions interleaved with add/mul, 3 rounds.
profile_function(name="fsummulti", f=fsummulti, inp=inp)
94
95
def fsummulti2(x):
    """Run 30 rounds of "add ``x.sum()``, then multiply by ``x.sum()``", starting from 0."""
    acc = 0
    rounds_left = 30
    while rounds_left > 0:
        acc = acc + x.sum()
        acc = acc * x.sum()
        rounds_left -= 1
    return acc
102
103
# Same pattern as fsummulti, 30 rounds.
profile_function(name="fsummulti2", f=fsummulti2, inp=inp)
105
106
def fcos(x):
    """Accumulate ``x.cos()`` three times, starting from 0.

    Written unrolled; the op sequence (cos, add) x 3 is unchanged.
    """
    acc = 0
    acc = acc + x.cos()
    acc = acc + x.cos()
    acc = acc + x.cos()
    return acc
112
113
# Three identical elementwise cos ops.
profile_function(name="fcos", f=fcos, inp=inp)
115
116
def fcos2(x):
    """Accumulate ``x.cos()`` 30 times, starting from 0, recomputing it each step."""
    acc = 0
    steps_left = 30
    while steps_left:
        acc = acc + x.cos()
        steps_left -= 1
    return acc
122
123
# Thirty identical elementwise cos ops.
profile_function(name="fcos2", f=fcos2, inp=inp)
125