# xref: /aosp_15_r20/external/pytorch/test/inductor/test_kernel_benchmark.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: inductor"]
import contextlib
import os
import subprocess
import sys
from unittest.mock import patch

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import rand_strided
from torch._inductor import config
from torch._inductor.codecache import PyCodeCache
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class TestKernelBenchmark(TestCase):
    device_type = GPU_TYPE

    @classmethod
    def setUpClass(cls):
        cls.exit_stack = contextlib.ExitStack()
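        # benchmark_kernel=True makes Inductor emit a benchmark_compiled_module()
        # helper in every generated module; get_compiled_module() below relies on
        # that attribute to locate the module in PyCodeCache.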
        cls.exit_stack.enter_context(patch.object(config, "benchmark_kernel", True))

    @classmethod
    def tearDownClass(cls):
        cls.exit_stack.close()

    def setUp(self):
        super().setUp()
        PyCodeCache.cache.clear()

    def get_compiled_module(self):
        compiled_module = None
        for v in PyCodeCache.cache.values():
            if hasattr(v, "benchmark_compiled_module"):
                self.assertTrue(
                    compiled_module is None, "Found multiple compiled modules"
                )
                compiled_module = v

        self.assertTrue(compiled_module is not None)
        return compiled_module

    def verify_compiled_kernels(self, GB_count=1):
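        # GB_count is the number of "GB/s" readings expected in the benchmark
        # output, i.e. roughly one per benchmarked kernel (or kernel config).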
        compiled_module = self.get_compiled_module()

        # now run the compiled module in subprocess and check its output
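        # ("-kc" combines the harness's -k flag, which benchmarks individual
        # kernels, with -c, which benchmarks all autotuning configs)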
        bench_out = subprocess.check_output(
            f"{sys.executable} {compiled_module.__file__} -kc".split(),
            stderr=subprocess.STDOUT,
        ).decode()

        # make sure we have the bandwidth information in the output
        FileCheck().check_count(
            "GB/s",
            GB_count,
            exactly=1,
        ).run(bench_out)

    def verify_remove_inductor_deps(self, compiled_module):
        try:
            out = subprocess.check_output(
                f"{sys.executable} {compiled_module.__file__}".split(),
                env={**os.environ.copy(), "TORCHINDUCTOR_DUMP_LAUNCH_PARAMS": "1"},
                stderr=subprocess.STDOUT,
            )
        except subprocess.CalledProcessError as e:
            print(
                "Failed when running triton code with TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1",
                e,
            )
            print(e.output.decode())
            raise e
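        # get_clean_triton rewrites the generated module into standalone Triton
        # code; the assertions below check that the Inductor-specific
        # @triton_heuristics decorators and the kernel.run(...) launch wrappers
        # are gone, and then the cleaned file is executed as a sanity check.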
        from torch.utils._get_clean_triton import get_clean_triton

        cleaned_triton = get_clean_triton(
            compiled_module.__file__, f"{compiled_module.__file__}.cleaned"
        )
        self.assertTrue("@triton_heuristics" not in cleaned_triton)
        self.assertTrue(".run(" not in cleaned_triton)
        try:
            out = subprocess.check_output(
                f"{sys.executable} {compiled_module.__file__}.cleaned".split(),
                stderr=subprocess.STDOUT,
            )
        except subprocess.CalledProcessError as e:
            print("Failed when running cleaned triton", e)
            print(e.output.decode())
            print(cleaned_triton)
            raise e
        return cleaned_triton

    def check_bandwidth(self, compiled_module, num_gb):
        # now run the compiled module in subprocess and check its output
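        # ("-k" benchmarks each kernel; the harness reports the estimated memory
        # traffic as "<num_gb> GB", which the FileCheck below matches exactly once)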
        bench_out = subprocess.check_output(
            f"{sys.executable} {compiled_module.__file__} -k".split(),
            stderr=subprocess.STDOUT,
        ).decode()

        # make sure we have the bandwidth information in the output
        FileCheck().check_count(
            f"{num_gb} GB ",
            1,
            exactly=1,
        ).run(bench_out)

    def test_pw_kernel_benchmark(self):
        @torch.compile
        def f(x):
            return torch.sin(x) + torch.cos(x)

        inp = torch.rand(2, 3).to(device=GPU_TYPE)
        out = f(inp)
        self.verify_compiled_kernels()

    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    @fresh_inductor_cache()
    def test_matmul_triton_kernel_benchmark(self):
        M = 12544
        N = 256
        K = 64
        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(N, K, dtype=torch.float16, device=GPU_TYPE).t()

        @torch.compile
        def f(a, b):
            return torch.relu(a @ b)

        f(a, b)
        self.verify_compiled_kernels()

    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    @fresh_inductor_cache()
    def test_mm_triton_kernel_benchmark(self):
        M = 2048
        N = 2432
        K = 1949
        K_2 = 3581
        a = rand_strided((M, K_2), (K_2, 1), device="cuda", dtype=torch.float16)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float16)

        @torch.compile
        def f(a, b):
            a_1 = torch.narrow(a, 1, 0, K)
            c = torch.mm(a_1, b)
            return c

        f(a, b)
        self.verify_compiled_kernels(GB_count=3)

        # make sure we correctly generate the grid info
        compiled_module = self.get_compiled_module()
        with open(compiled_module.__file__) as f:
            source_code = f.read()
        lines = source_code.split("\n")
        meta = [l for l in lines if "meta0 = {" in l]
        scope = {}
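        # Re-evaluate the generated "meta0 = {...}" line and recompute the launch
        # grid with mm_grid; the generated source should embed exactly this grid
        # string (FileCheck expects it to appear twice).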
        from torch._inductor.kernel.mm_common import mm_grid

        exec(meta[0], scope)
        grid = mm_grid(M, N, scope["meta0"])
        FileCheck().check_count(
            f"grid={grid}",
            2,
            exactly=1,
        ).run(source_code)

    def test_matmul_bandwidth_computation(self):
        """
        The test does a matmul and then a mul. Without max-autotune, we use
        the aten matmul, so there is a single triton kernel for the mul.
        The generated kernel looks like:

            @triton.jit
            def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):

        Note the in_out_ptr0 argument. It's for a 1000x1000 tensor, but it's
        updated in place, so when computing the bandwidth we should count
        the total memory access as 2 * 1000 * 1000 * 4 = 8MB. This amount is
        what this test asserts.
        """
        torch.set_float32_matmul_precision("high")  # suggested by a warning

        @torch.compile
        def f(x, y):
            z = x @ y
            w = z * z
            return w

        M, N, K = 1000, 1000, 10
        x = torch.rand(M, K).to(device=GPU_TYPE)
        y = torch.rand(K, N).to(device=GPU_TYPE)
        out = f(x, y)

        compiled_module = self.get_compiled_module()

        self.check_bandwidth(compiled_module, 0.008)

    def test_unused_input_bandwidth_computation(self):
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            return a + c

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(c, 0)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_c + size_out
        # num_gb = (5 * 1000000 + 5 * 1000000 + 5 * 1000000) * 2 / 1e9
        #        = 0.030
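        # (b is passed to f but never read, so it does not contribute to the estimate)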
        self.check_bandwidth(compiled_module, "0.030")

    def test_reduction_bandwidth_computation(self):
        @torch.compile
        def f(a):
            return torch.sum(a, dim=1)

        a = torch.rand(1000, 20, 1000, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a,)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_out
        # num_gb = (1000 * 20 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.042
        self.check_bandwidth(compiled_module, "0.042")

    @config.patch(max_autotune=True)
    def test_fused_layernorm_bandwidth_computation(self):
        M, N = 10, 1000000

        @torch.compile
        def f(a, b, c, d):
            x0 = a + b
            x1 = torch.nn.functional.layer_norm(
                x0, normalized_shape=(N,), weight=c, bias=d, eps=1e-05
            )
            x2 = torch.sigmoid(x1)
            return x0 * x2

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        d = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a, b, c, d)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_b + size_c + size_d + size_out
        # num_gb = (10 * 1000000 + 1000000 + 1000000 + 1000000 + 10 * 1000000) * 2 / 1e9
        #        = 0.046
        self.check_bandwidth(compiled_module, "0.046")

    def test_slice_add_cat_bandwidth_computation(self):
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            x0 = torch.narrow(b, 1, N, N)
            # broadcasting
            x1 = x0 + c
            return torch.cat([a, x1], dim=1)

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # we overestimate the size of "slice_b" due to torch.cat
        # num_gb = size_a + size_slice_b + size_c + size_out
        # num_gb = (5 * 1000000 + 5 * 2000000 + 1000000 + 5 * 2000000) * 2 / 1e9
        #        = 0.052
        self.check_bandwidth(compiled_module, "0.052")

    def test_slice_add_bandwidth_computation(self):
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            x0 = torch.narrow(b, 1, N, N)
            return a + x0 + c

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_slice_b + size_c + size_out
        # num_gb = (5 * 1000000 + 5 * 1000000 + 1000000 + 5 * 1000000) * 2 / 1e9
        #        = 0.032
        self.check_bandwidth(compiled_module, "0.032")

    def test_mm_slice_add_bandwidth_computation(self):
        M, N, K = 1000, 1000, 30

        @torch.compile
        def f(a, b, c):
            x0 = torch.mm(a, b)
            x1 = torch.narrow(c, 1, 20 * N, N)
            x2 = torch.narrow(c, 1, 21 * N, N)
            return x0 + x1 + x2

        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # torch.mm becomes an extern kernel, so we measure the nbytes
        # for the pointwise add kernel:
        # num_gb = x0 + 2 * size_slice_c + size_out
        # num_gb = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.008
        num_gb = "0.008"
        if GPU_TYPE == "xpu":
            # On the XPU backend, mm + add + add is fused into addmm + add.
            # CUDA prefers not to fuse add into mm; see
            # `should_prefer_unfused_addmm` in torch/_inductor/fx_passes/post_grad.py
            num_gb = "0.006"

        self.check_bandwidth(compiled_module, num_gb)

    def test_mm_slice_add_bandwidth_computation_2(self):
        M, N, K = 1000, 1000, 30

        @torch.compile
        def f(a, b, c):
            x0 = torch.mm(a, b)
            x1 = torch.narrow(c, 1, 20 * N, N)
            x2 = torch.narrow(c, 1, 20 * N, N)
            return x0 + x1 + x2

        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # torch.mm becomes an extern kernel, so we measure the nbytes
        # for the pointwise add kernel:
        # num_gb = x0 + size_slice_c + size_out
        # num_gb = (1000 * 1000 + 1000 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.006
        # note that we only count one size_slice_c because two accesses
        # have the same index.
        self.check_bandwidth(compiled_module, "0.006")

    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_slice_mm_bandwidth_computation(self):
        M, N, K = 1000, 2000, 3000

        @torch.compile
        def f(a, b):
            x = torch.narrow(a, 1, K, K)
            return torch.mm(x, b)

        a = torch.rand(M, 3 * K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        inputs = (a, b)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()

        # c[1000, 2000] = x[1000, 3000] @ b[3000, 2000]
        # num_gb = (1000 * 2000 + 1000 * 3000 + 3000 * 2000) * 2 / 1e9
        #        = 0.022
        self.check_bandwidth(compiled_module, "0.022")

    def test_star_dep(self):
        """
        Test the bandwidth estimation for StarDep
        """
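        # a[b] = 3.0 is an indirect scatter, so the write to `a` shows up as a
        # StarDep (a dependency on the whole buffer); the estimate below therefore
        # corresponds to a's full 200MB allocation rather than only the rows written.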

        @torch.compile
        def f(a, b):
            a[b] = 3.0

        a = torch.rand(10000, 5000, device=GPU_TYPE)
        b = torch.randint(
            0, 10000, [20000], device=GPU_TYPE, dtype=torch.int32
        ).unsqueeze(1)
        f(a, b)
        compiled_module = self.get_compiled_module()
        # 20000 * 4 = 80KB for b
        # 10000 * 5000 * 4 = 200MB for a
        self.check_bandwidth(compiled_module, "0.200")

    def test_split_scan(self):
        @torch.compile
        def f(a):
            return a.cumsum(-1)

        a = torch.rand(10000, 5000, device=GPU_TYPE)
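        # Flattening to a single 50M-element dimension makes the cumsum large
        # enough to exercise the split-scan lowering this test is named after.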
        f(a.reshape(-1))
        compiled_module = self.get_compiled_module()
        # 10000 * 5000 * 4 = 200 MB for a
        # Double that for output as well
        self.check_bandwidth(compiled_module, "0.400")

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps(self):
        @torch.compile
        def f(a):
            return a.cos().sin()

        a = torch.randn(5, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        cleaned_triton = self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps_multiple_kernels(self):
        @torch.compile
        def f(a):
            a = torch.mm(a, a)
            a = a.cos().sin()
            a = torch.mm(a, a)
            a = torch.softmax(a, dim=-1)
            return a

        a = torch.randn(5, 5, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_remove_inductor_deps_templates(self):
        @torch.compile
        def f(a):
            a = torch.mm(a, a)
            a = a.cos()
            a = torch.mm(a, a)
            a = a.sin()
            return a

        a = torch.randn(128, 128, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps_scalar(self):
        @torch.compile
        def f(a, b):
            return a + b

        a = torch.tensor(1.0, device=GPU_TYPE)
        b = torch.tensor(2.0, device=GPU_TYPE)
        f(a, b)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()