# xref: /aosp_15_r20/external/pytorch/benchmarks/fuser/run_benchmarks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
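"""Microbenchmarks for TorchScript pointwise-operator fusion.

Every operator below is run over every input shape, first in eager mode and
then under three JIT fusers (NNC/TensorExpr, nvFuser, and the legacy fuser).
One CSV row is printed per measurement:

    fuser,device,operator,shape,time

with time in microseconds. A usage sketch (assumes a CUDA-capable machine,
since all inputs are moved to "cuda"):

    python run_benchmarks.py --operators add,mul --shapes small,medium
"""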
import inspect
import itertools
import sys
import time

import click

import torch

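# Global setup: keep CPU runs single-threaded for stable timings, and disable
# fusion-group inlining (a debug flag) so that even the tiny test cases below
# keep exercising the fusers instead of being inlined away.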
torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)


def rand(*shape):
    # Uniform values in [1, 17): strictly positive, so the log/div operator
    # cases below stay well-defined.
    return torch.rand(*shape).mul(16).add(1)


# ------------------------------------------------------------------------------
# Shape test cases: each returns a fresh tuple of input tensors.
# ------------------------------------------------------------------------------
def scalar():
    return (rand(1), rand(1))


def small():
    return (rand(32), rand(32))


def small_2d():
    return (rand(1, 32), rand(1, 32))


def small_broadcast():
    return (rand(4, 32), rand(32))


def medium():
    return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))


def medium_sliced():
    return (rand(32, 12, 64, 64)[..., ::2], rand(32, 12, 64, 64)[..., ::2])


def medium_transpose():
    return (
        rand(32, 12, 64, 64).transpose(-1, -2),
        rand(32, 12, 64, 64).transpose(-1, -2),
    )


def medium2():
    return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))


def medium3d():
    return (rand(16, 32, 64), rand(16, 32, 64))


def medium_channels_last():
    return (
        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
    )


def medium_broadcast():
    return (rand(32, 12, 64, 64), rand(64))


def medium_broadcast_channels_last():
    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), rand(3, 1, 1))


def large():
    return (rand(8192, 8192), rand(8192, 8192))


def large_transpose():
    return (rand(8192, 8192).transpose(0, 1), rand(8192, 8192).transpose(0, 1))


def large_channels_last():
    return (
        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
    )


def broadcast_narrow_57611():
    return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))


def large_broadcast_66816():
    return (rand(64, 8, 256, 162), rand(256, 162))

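# The cases above cover contiguous, sliced, transposed, channels-last, and
# broadcast layouts; the numeric suffixes (57611, 66816) presumably refer to
# the PyTorch GitHub issues that motivated them.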
# ------------------------------------------------------------------------------
# Operator test cases. Most fold in an extra pointwise op (e.g. `3 * a`),
# presumably so the JIT sees a multi-op chain worth fusing rather than a
# single kernel.
# ------------------------------------------------------------------------------
def add(a, b):
    return 3 * a + b


def sub(a, b):
    return 3 * a - b


def mul(a, b):
    return 3 * a * b


def div(a, b):
    return 3 * a / b


def relu(a):
    return (3 * a).relu()


def sigmoid(a):
    return (3 * a).sigmoid()


def tanh(a):
    return (3 * a).tanh()


def log(a):
    return (3 * a).log()


def exp(a):
    return (3 * a).exp()


def square(a):
    return (3 * a) ** 2


def fma(a, b):
    return a * b + b


def mul_mul_add_66816(a, b, c):
    return (a * b) + (a * c)


def hardswish_int(a):
    # Same as hardswish() below, but with integer clamp bounds.
    return a * (a + 3).clamp(0, 6) / 6


def hardswish(a):
    return a * (a + 3).clamp(0.0, 6.0) / 6


def native_hardswish(a):
    return torch._C._nn.hardswish(a * 3)


def softplus(a):
    # softplus(a) = log1p(exp(a)); the `* 1.0` and `/ 1.0` mirror a beta of 1.
    return (a * 1.0).exp().log1p() / 1.0


def mish(a):
    # mish(a) = a * tanh(softplus(a))
    return a * ((a * 1.0).exp().log1p() / 1.0).tanh()


# Registries used by the CLI; --shapes and --operators select entries by name.
SHAPES = [
    scalar,
    small,
    small_2d,
    small_broadcast,
    medium,
    medium2,
    medium3d,
    medium_sliced,
    medium_transpose,
    medium_channels_last,
    medium_broadcast,
    medium_broadcast_channels_last,
    large,
    large_transpose,
    large_channels_last,
    broadcast_narrow_57611,
    large_broadcast_66816,
]

OPERATORS = [
    add,
    sub,
    mul,
    div,
    relu,
    sigmoid,
    tanh,
    log,
    exp,
    square,
    fma,
    mul_mul_add_66816,
    hardswish_int,
    hardswish,
    native_hardswish,
    softplus,
    mish,
]

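# Timing helpers. CPU timing uses wall-clock perf_counter; CUDA timing uses
# CUDA events, which time the GPU stream itself (Event.elapsed_time reports
# milliseconds, hence the division by 1e3 below).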
def time_cpu(fn, args, iters):
    s = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    e = time.perf_counter()
    return e - s


def time_cuda(fn, args, iters):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1e3


def benchmark_with_timer(fn, args, timer):
    # Warm up (this also triggers JIT profiling/compilation), then calibrate
    # the iteration count to roughly one second of work.
    timer(fn, args, 3)
    calibration = timer(fn, args, 1)
    # At least one iteration, so slow cases (calibration >= 1s) don't yield
    # iters == 0 and a division by zero.
    iters = max(1, int(1.0 / calibration))
    return timer(fn, args, iters) / iters


def benchmark(fn, args):
    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
    return benchmark_with_timer(fn, args, timer)
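# A usage sketch (hypothetical): `benchmark(relu, (rand(32),))` returns the
# mean seconds per call, dispatching on the device of the first argument.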


def micros(s):
    # Format seconds as microseconds with one decimal place.
    return f"{s * 1e6:.1f}"


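# Fuser configurations. Each helper routes TorchScript through exactly one
# backend: nvFuser and NNC (TensorExpr) run under the profiling executor,
# while the legacy fuser runs under the non-profiling executor.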
def with_nvfuser():
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(True)
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def with_nnc():
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_fuser_enabled(True)
    torch._C._jit_set_nvfuser_enabled(False)
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def with_legacy():
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)


@click.command()
@click.option("--operators", default=None)
@click.option("--shapes", default=None)
def run_benchmarks(operators, shapes):
    if operators is None:
        operators = OPERATORS
    else:
        operators = [globals()[k] for k in operators.split(",")]
    if shapes is None:
        shapes = SHAPES
    else:
        shapes = [globals()[k] for k in shapes.split(",")]

    # CSV header; each time is in microseconds (see micros()).
    print("fuser,device,operator,shape,time")
    for shape, operator in itertools.product(shapes, operators):
        # Pad or trim the generated inputs to match the operator's arity,
        # reusing the last tensor for any extra parameters.
        nargs = len(inspect.signature(operator).parameters)
        args = shape()
        if nargs > len(args):
            args = list(args)
            args += [args[-1]] * (nargs - len(args))
        args = args[:nargs]
        # All measurements run on CUDA.
        args = [arg.to("cuda") for arg in args]

        result = benchmark(operator, args)
        print(
            ",".join(
                [
                    "eager",
                    args[0].device.type,
                    operator.__name__,
                    shape.__name__,
                    micros(result),
                ]
            )
        )

        def bench(name):
            # Trace under the currently selected fuser so each backend
            # compiles its own copy of the operator.
            traced = torch.jit.trace(operator, args)
            result = benchmark(traced, args)
            print(
                ",".join(
                    [
                        name,
                        args[0].device.type,
                        operator.__name__,
                        shape.__name__,
                        micros(result),
                    ]
                )
            )
            sys.stdout.flush()

        with_nnc()
        bench("nnc")
        with_nvfuser()
        bench("nvfuser")
        with_legacy()
        bench("legacy")


if __name__ == "__main__":
    run_benchmarks()