xref: /aosp_15_r20/external/pytorch/functorch/benchmarks/cse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport torch
2*da0073e9SAndroid Build Coastguard Workerimport torch.fx as fx
3*da0073e9SAndroid Build Coastguard Workerfrom functorch import make_fx
4*da0073e9SAndroid Build Coastguard Workerfrom torch._functorch.compile_utils import fx_graph_cse
5*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler import profile, ProfilerActivity
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
def profile_it(f, inp, warmup=5, itr=5):
    """Return the summed profiled CUDA time per call of ``f(inp)``.

    ``f`` is first called ``warmup`` times without profiling. It is then called
    ``itr`` times under ``torch.profiler.profile`` with CUDA activity recording.
    The ``cuda_time_total`` of every aggregated event is summed and divided by
    ``itr``.

    Args:
        f: Callable invoked as ``f(inp)``.
        inp: Argument passed to ``f`` on every call.
        warmup: Number of untimed calls made before profiling (default 5).
        itr: Number of profiled calls to average over (default 5).

    Returns:
        The summed ``cuda_time_total`` of all profiler events divided by ``itr``.

    Raises:
        ValueError: If ``itr`` is not positive (the average would be undefined).
    """
    if itr <= 0:
        raise ValueError(f"itr must be positive, got {itr}")

    for _ in range(warmup):
        f(inp)
    # Warm-up kernels are launched asynchronously; wait for them to finish so
    # they are not still executing (and recorded) inside the profiled window.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(itr):
            f(inp)

    cuda_time_total = sum(e.cuda_time_total for e in prof.key_averages())
    return cuda_time_total / itr
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
def profile_function(name, f, inp):
    """Trace ``f``, run FX CSE on its graph, and print a CSV-style comparison.

    The printed row is:
        name, avg CUDA time of the traced module, avg CUDA time of the
        CSE'd module, number of nodes removed by CSE, node count before CSE

    Timings come from ``profile_it``; node counts come from the FX graphs.
    """
    traced = make_fx(f)(inp)
    original_graph = traced.graph
    deduped = fx.GraphModule(traced, fx_graph_cse(original_graph))

    # The eager FX modules are compared, not their TorchScript versions:
    # TorchScript already performs some CSE, which would hide the effect.
    time_before = profile_it(traced, inp)
    time_after = profile_it(deduped, inp)

    nodes_before = len(original_graph.nodes)
    nodes_removed = nodes_before - len(deduped.graph.nodes)

    print(f"{name}, {time_before}, {time_after}, {nodes_removed}, {nodes_before}")
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker
# A fixed-seed CUDA generator makes every run benchmark the same input values.
g_gpu = torch.Generator(device="cuda").manual_seed(2147483647)
# Shared 1-D input of 2**20 standard-normal values on the GPU for all benchmarks.
inp = torch.randn(1 << 20, device="cuda", generator=g_gpu)
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker
def f1(x):
    """Apply cosine twice; a baseline graph with no repeated subexpression."""
    return torch.cos(torch.cos(x))
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
profile_function(name="f1", f=f1, inp=inp)
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
def fsum(x):
    """Add four independent ``x.sum()`` results.

    The four identical reductions are intentional: they are redundant
    subexpressions for ``fx_graph_cse`` to merge.
    """
    s0, s1, s2, s3 = (x.sum() for _ in range(4))
    return s0 + s1 + s2 + s3
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker
profile_function(name="fsum", f=fsum, inp=inp)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker
def fconcat(x):
    """Concatenate ``x`` with itself twice, separately, and add the two results.

    The duplicated ``torch.cat`` is intentional: it is the redundant work that
    ``fx_graph_cse`` is expected to collapse.
    """
    pair = (x, x)
    first = torch.cat(pair)
    second = torch.cat(pair)
    return first + second
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
profile_function(name="fconcat", f=fconcat, inp=inp)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker
def fsum2(x):
    """Accumulate 31 separately computed ``x.sum()`` values.

    The builtin ``sum`` starts from one ``x.sum()`` and adds 30 more, one at a
    time. This yields the same sum/add op sequence as an explicit loop, which
    gives ``fx_graph_cse`` a long chain of duplicate reductions.
    """
    return sum((x.sum() for _ in range(30)), start=x.sum())
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
profile_function(name="fsum2", f=fsum2, inp=inp)
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker
def fsummulti(x):
    """Run 3 rounds of ``acc = (acc + x.sum()) * x.sum()`` starting from 0.

    Each round computes ``x.sum()`` twice, so the reductions feed both an add
    and a mul. These are the duplicates that CSE should merge.
    """
    acc = 0
    for _ in range(3):
        acc = (acc + x.sum()) * x.sum()
    return acc
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
profile_function(name="fsummulti", f=fsummulti, inp=inp)
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker
def fsummulti2(x):
    """Run 30 rounds of ``acc = (acc + x.sum()) * x.sum()`` starting from 0.

    This is the larger variant of ``fsummulti``: 60 duplicate reductions
    interleaved with add/mul ops.
    """
    acc = 0
    for _ in range(30):
        acc = (acc + x.sum()) * x.sum()
    return acc
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
profile_function(name="fsummulti2", f=fsummulti2, inp=inp)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker
def fcos(x):
    """Add three separately computed ``x.cos()`` results onto an initial 0.

    The builtin ``sum`` starts from 0 and adds each cosine in turn. This is the
    same cos/add op sequence as an explicit accumulation loop.
    """
    return sum(x.cos() for _ in range(3))
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker
profile_function(name="fcos", f=fcos, inp=inp)
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker
def fcos2(x):
    """Add thirty separately computed ``x.cos()`` results onto an initial 0.

    This is the larger variant of ``fcos``: 30 duplicate elementwise cosines
    for CSE to merge.
    """
    return sum(x.cos() for _ in range(30))
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker
profile_function(name="fcos2", f=fcos2, inp=inp)
125