xref: /aosp_15_r20/external/pytorch/benchmarks/fuser/run_benchmarks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport inspect
2*da0073e9SAndroid Build Coastguard Workerimport itertools
3*da0073e9SAndroid Build Coastguard Workerimport sys
4*da0073e9SAndroid Build Coastguard Workerimport time
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport click
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
# Run torch with a single thread for every benchmark in this file.
torch.set_num_threads(1)
# NOTE(review): private debug switch; presumably keeps fusion groups from
# being inlined back into the graph so the fused kernels are what gets
# timed -- confirm against the JIT sources.
torch._C._debug_set_fusion_group_inlining(False)
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
def rand(*shape):
    """Return a random tensor of ``shape`` with values shifted into [1, 17).

    ``torch.rand`` samples from [0, 1); the result is multiplied by 16 and
    offset by 1.
    """
    unit = torch.rand(*shape)
    return unit * 16 + 1
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
# ------------------------------------------------------------------------------
# Shape test cases: each function returns the tuple of input tensors to benchmark
# ------------------------------------------------------------------------------
def scalar():
    """Return a pair of one-element tensors."""
    lhs = rand(1)
    rhs = rand(1)
    return lhs, rhs
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
def small():
    """Return a pair of 1-D tensors with 32 elements each."""
    length = 32
    return tuple(rand(length) for _ in range(2))
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
def small_2d():
    """Return a pair of (1, 32) tensors."""
    dims = (1, 32)
    lhs = rand(*dims)
    rhs = rand(*dims)
    return lhs, rhs
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
def small_broadcast():
    """Return a (4, 32) tensor and a (32,) tensor that broadcasts against it."""
    full = rand(4, 32)
    row = rand(32)
    return full, row
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
def medium():
    """Return a pair of (32, 12, 64, 64) tensors."""
    dims = (32, 12, 64, 64)
    return tuple(rand(*dims) for _ in range(2))
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker
def medium_sliced():
    """Return a pair of (32, 12, 64, 64) tensors sliced with step 2 on the last dim."""
    dims = (32, 12, 64, 64)
    lhs = rand(*dims)[..., ::2]
    rhs = rand(*dims)[..., ::2]
    return lhs, rhs
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
def medium_transpose():
    """Return a pair of (32, 12, 64, 64) tensors with the last two dims swapped."""
    dims = (32, 12, 64, 64)
    return tuple(rand(*dims).transpose(-1, -2) for _ in range(2))
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker
def medium2():
    """Return a pair of (32, 3, 224, 224) tensors."""
    dims = (32, 3, 224, 224)
    lhs = rand(*dims)
    rhs = rand(*dims)
    return lhs, rhs
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
def medium3d():
    """Return a pair of (16, 32, 64) tensors."""
    dims = (16, 32, 64)
    return tuple(rand(*dims) for _ in range(2))
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker
def medium_channels_last():
    """Return a pair of (32, 3, 224, 224) tensors converted to channels-last format."""
    dims = (32, 3, 224, 224)
    lhs = rand(*dims).to(memory_format=torch.channels_last)
    rhs = rand(*dims).to(memory_format=torch.channels_last)
    return lhs, rhs
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker
def medium_broadcast():
    """Return a (32, 12, 64, 64) tensor and a (64,) tensor that broadcasts against it."""
    full = rand(32, 12, 64, 64)
    row = rand(64)
    return full, row
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
def medium_broadcast_channels_last():
    """Return a channels-last (32, 3, 223, 223) tensor and a (3, 1, 1) tensor."""
    full = rand(32, 3, 223, 223).to(memory_format=torch.channels_last)
    per_channel = rand(3, 1, 1)
    return full, per_channel
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker
def large():
    """Return a pair of (8192, 8192) tensors."""
    side = 8192
    lhs = rand(side, side)
    rhs = rand(side, side)
    return lhs, rhs
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
def large_transpose():
    """Return a pair of transposed (8192, 8192) tensors."""
    side = 8192
    return tuple(rand(side, side).transpose(0, 1) for _ in range(2))
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker
def large_channels_last():
    """Return a pair of (32, 32, 256, 256) tensors converted to channels-last format."""
    dims = (32, 32, 256, 256)
    return tuple(
        rand(*dims).to(memory_format=torch.channels_last) for _ in range(2)
    )
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker
def broadcast_narrow_57611():
    """Return a (1, 32, 32, 2) tensor and a (1024, 1, 1, 2) tensor."""
    lhs = rand(1, 32, 32, 2)
    rhs = rand(1024, 1, 1, 2)
    return lhs, rhs
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker
def large_broadcast_66816():
    """Return a (64, 8, 256, 162) tensor and a (256, 162) tensor."""
    full = rand(64, 8, 256, 162)
    plane = rand(256, 162)
    return full, plane
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
# ------------------------------------------------------------------------------
# Operator test cases: each function is an elementwise expression to benchmark
# ------------------------------------------------------------------------------
def add(a, b):
    """Return ``3 * a + b``."""
    scaled = 3 * a
    return scaled + b
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker
def sub(a, b):
    """Return ``3 * a - b``."""
    scaled = 3 * a
    return scaled - b
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker
def mul(a, b):
    """Return ``3 * a * b``."""
    scaled = 3 * a
    return scaled * b
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker
def div(a, b):
    """Return ``3 * a / b``."""
    scaled = 3 * a
    return scaled / b
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker
def relu(a):
    """Return ReLU applied to ``3 * a``."""
    scaled = 3 * a
    return torch.relu(scaled)
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker
def sigmoid(a):
    """Return the sigmoid of ``3 * a``."""
    scaled = 3 * a
    return torch.sigmoid(scaled)
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker
def tanh(a):
    """Return the hyperbolic tangent of ``3 * a``."""
    scaled = 3 * a
    return torch.tanh(scaled)
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker
def log(a):
    """Return the natural logarithm of ``3 * a``."""
    scaled = 3 * a
    return torch.log(scaled)
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker
def exp(a):
    """Return ``e`` raised to ``3 * a``."""
    scaled = 3 * a
    return torch.exp(scaled)
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker
def square(a):
    """Return ``(3 * a)`` squared."""
    scaled = 3 * a
    return scaled**2
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker
def fma(a, b):
    """Return ``a * b + b`` (a multiply followed by an add of ``b``)."""
    product = a * b
    return product + b
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker
def mul_mul_add_66816(a, b, c):
    """Return ``a * b + a * c`` as two products followed by a sum."""
    left = a * b
    right = a * c
    return left + right
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker
def hardswish_int(a):
    """Return ``a * clamp(a + 3, 0, 6) / 6`` using integer clamp bounds."""
    gate = (a + 3).clamp(0, 6)
    return a * gate / 6
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker
def hardswish(a):
    """Return ``a * clamp(a + 3, 0.0, 6.0) / 6`` using float clamp bounds."""
    gate = (a + 3).clamp(0.0, 6.0)
    return a * gate / 6
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker
def native_hardswish(a):
    """Return torch's built-in hardswish applied to ``a * 3``."""
    scaled = a * 3
    return torch._C._nn.hardswish(scaled)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker
def softplus(a):
    """Return ``log1p(exp(a * 1.0)) / 1.0``, a hand-written softplus."""
    scaled = a * 1.0
    return torch.log1p(torch.exp(scaled)) / 1.0
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker
def mish(a):
    """Return ``a * tanh(log1p(exp(a * 1.0)) / 1.0)``, a hand-written mish."""
    soft = torch.log1p(torch.exp(a * 1.0)) / 1.0
    return a * torch.tanh(soft)
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker
# Default shape cases, in the order they are benchmarked.
SHAPES = [
    # Small inputs.
    scalar, small, small_2d, small_broadcast,
    # Medium inputs, including sliced, transposed and channels-last layouts.
    medium, medium2, medium3d,
    medium_sliced, medium_transpose, medium_channels_last,
    medium_broadcast, medium_broadcast_channels_last,
    # Large inputs.
    large, large_transpose, large_channels_last,
    # Broadcast cases carrying a numeric suffix in their names.
    broadcast_narrow_57611, large_broadcast_66816,
]
189*da0073e9SAndroid Build Coastguard Worker
# Default operator cases, in the order they are benchmarked.
OPERATORS = [
    # Binary arithmetic.
    add, sub, mul, div,
    # Unary pointwise functions.
    relu, sigmoid, tanh, log, exp, square,
    # Multi-op expressions.
    fma, mul_mul_add_66816,
    # Activation functions.
    hardswish_int, hardswish, native_hardswish, softplus, mish,
]
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker
def time_cpu(fn, args, iters):
    """Call ``fn(*args)`` ``iters`` times and return the elapsed wall time.

    The time is measured with ``time.perf_counter`` and returned in seconds.
    """
    clock = time.perf_counter
    began = clock()
    for _ in range(iters):
        fn(*args)
    return clock() - began
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker
def time_cuda(fn, args, iters):
    """Call ``fn(*args)`` ``iters`` times and return the elapsed time in seconds.

    Timing uses a pair of CUDA events recorded around the loop; the device is
    synchronized before the elapsed time is read.
    """
    begin_evt, finish_evt = (
        torch.cuda.Event(enable_timing=True) for _ in range(2)
    )
    begin_evt.record()
    for _ in range(iters):
        fn(*args)
    finish_evt.record()
    torch.cuda.synchronize()
    # elapsed_time reports milliseconds; convert to seconds.
    elapsed_ms = begin_evt.elapsed_time(finish_evt)
    return elapsed_ms / 1e3
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker
def benchmark_with_timer(fn, args, timer):
    """Return the average time of one ``fn(*args)`` call as measured by ``timer``.

    Args:
        fn: Callable to benchmark.
        args: Positional arguments passed to ``fn``.
        timer: Callable ``timer(fn, args, iters)`` returning the total time in
            seconds for ``iters`` calls.

    Returns:
        Seconds per call, averaged over roughly one second of measurement.
    """
    # Warm-up runs; their timing is discarded.
    timer(fn, args, 3)
    calibration = timer(fn, args, 1)
    if calibration > 0:
        # Aim for about one second of measurement, but always run at least
        # once: a call slower than one second would otherwise give zero
        # iterations and a division by zero below.
        iters = max(1, int(1.0 / calibration))
    else:
        # The single call was below the timer's resolution; use a large fixed
        # count instead of dividing by zero.
        iters = 1_000_000
    return timer(fn, args, iters) / iters
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker
def benchmark(fn, args):
    """Benchmark ``fn(*args)`` with the timer matching the first argument's device.

    The CPU timer is used when ``args[0]`` lives on the CPU; otherwise the
    CUDA timer is used.
    """
    on_cpu = args[0].device.type == "cpu"
    if on_cpu:
        return benchmark_with_timer(fn, args, time_cpu)
    return benchmark_with_timer(fn, args, time_cuda)
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker
def micros(s):
    """Format a duration given in seconds as microseconds with one decimal."""
    microseconds = s * 1e6
    return format(microseconds, ".1f")
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker
def with_nvfuser():
    """Set the JIT flags for the nvfuser configuration.

    The CPU/GPU fuse overrides and the texpr fuser are turned off; nvfuser,
    the profiling executor and profiling mode are turned on.
    """
    jit_flags = torch._C
    jit_flags._jit_override_can_fuse_on_cpu(False)
    jit_flags._jit_override_can_fuse_on_gpu(False)
    jit_flags._jit_set_texpr_fuser_enabled(False)
    jit_flags._jit_set_nvfuser_enabled(True)
    jit_flags._jit_set_profiling_executor(True)
    jit_flags._jit_set_profiling_mode(True)
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker
def with_nnc():
    """Set the JIT flags for the NNC (texpr fuser) configuration.

    The CPU/GPU fuse overrides, the texpr fuser, the profiling executor and
    profiling mode are turned on; nvfuser is turned off.
    """
    jit_flags = torch._C
    jit_flags._jit_override_can_fuse_on_cpu(True)
    jit_flags._jit_override_can_fuse_on_gpu(True)
    jit_flags._jit_set_texpr_fuser_enabled(True)
    jit_flags._jit_set_nvfuser_enabled(False)
    jit_flags._jit_set_profiling_executor(True)
    jit_flags._jit_set_profiling_mode(True)
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker
def with_legacy():
    """Set the JIT flags for the legacy fuser configuration.

    The CPU/GPU fuse overrides are turned on; the texpr fuser, nvfuser, the
    profiling executor and profiling mode are turned off.
    """
    jit_flags = torch._C
    jit_flags._jit_override_can_fuse_on_cpu(True)
    jit_flags._jit_override_can_fuse_on_gpu(True)
    jit_flags._jit_set_texpr_fuser_enabled(False)
    jit_flags._jit_set_nvfuser_enabled(False)
    jit_flags._jit_set_profiling_executor(False)
    jit_flags._jit_set_profiling_mode(False)
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker
@click.command()
@click.option("--operators", default=None)
@click.option("--shapes", default=None)
def run_benchmarks(operators, shapes):
    """Benchmark each (shape, operator) pair in eager mode and under each fuser.

    Prints CSV rows ``fuser,device,operator,shape,time`` to stdout, where
    ``time`` is the average time per call in microseconds. Inputs are moved
    to ``"cuda"`` before timing, so a CUDA device is required.

    Args:
        operators: Comma-separated operator names taken from ``OPERATORS``,
            or ``None`` to run all of them.
        shapes: Comma-separated shape names taken from ``SHAPES``, or ``None``
            to run all of them.

    Raises:
        click.BadParameter: If a requested name is not a known case.
    """

    def resolve(spec, registry, option):
        # Look names up in the registry instead of globals(): only real
        # benchmark cases can be selected and a typo gives a readable error
        # rather than a bare KeyError.
        if spec is None:
            return registry
        by_name = {fn.__name__: fn for fn in registry}
        chosen = []
        for name in (part.strip() for part in spec.split(",")):
            if name not in by_name:
                raise click.BadParameter(
                    f"unknown name {name!r}; choose from: {', '.join(by_name)}",
                    param_hint=option,
                )
            chosen.append(by_name[name])
        return chosen

    def report(fuser, fn, args, operator, shape):
        # Time ``fn`` and emit one CSV row; flush so rows appear as they finish.
        row = [
            fuser,
            args[0].device.type,
            operator.__name__,
            shape.__name__,
            micros(benchmark(fn, args)),
        ]
        print(",".join(row))
        sys.stdout.flush()

    operators = resolve(operators, OPERATORS, "--operators")
    shapes = resolve(shapes, SHAPES, "--shapes")

    # Each entry is a fuser label and the function that sets its JIT flags.
    fusers = (
        ("nnc", with_nnc),
        ("nvfuser", with_nvfuser),
        ("legacy", with_legacy),
    )

    print("fuser,device,operator,shape,time")
    for shape, operator in itertools.product(shapes, operators):
        nargs = len(inspect.signature(operator).parameters)
        args = list(shape())
        # Pad by repeating the last tensor when the operator takes more inputs
        # than the shape provides (a non-positive count adds nothing), then
        # trim when it takes fewer.
        args += [args[-1]] * (nargs - len(args))
        args = [arg.to("cuda") for arg in args[:nargs]]

        report("eager", operator, args, operator, shape)

        for fuser, configure in fusers:
            # Flags must be set before tracing so the traced op uses them.
            configure()
            traced = torch.jit.trace(operator, args)
            report(fuser, traced, args, operator, shape)
333*da0073e9SAndroid Build Coastguard Worker
# Run the benchmarks only when this file is executed as a script.
# run_benchmarks is a click command, so --operators/--shapes come from the
# command line.
if __name__ == "__main__":
    run_benchmarks()
336