# Owner(s): ["module: inductor"]
import json
import unittest
from typing import Callable, Optional

import torch
import torch._dynamo.utils
import torch._inductor.test_case
import torch._inductor.utils
from torch._inductor import config
from torch.profiler import ProfilerActivity
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton


HAS_TRITON = has_triton()


class DynamoProfilerTests(torch._inductor.test_case.TestCase):
    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_triton_launch(self):
        # Verify that we get some sort of CPU-side indication of triton kernel launches
        # in the profile traces. Currently, those appear as `cuLaunchKernel`. If this
        # detail changes, the test can be updated or removed.
        @torch.compile
        def fn(x, y):
            return (x + y).sin().cos()

        x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))

        with torch.profiler.profile() as prof:
            fn(x, y)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace_json = json.load(f)

        self.assertIn("traceEvents", trace_json)
        events = trace_json["traceEvents"]
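        # A chrome-trace event is a JSON object roughly like this (a sketch;
        # the exact fields vary by event type and profiler version):
        #   {"ph": "X", "cat": "cuda_runtime", "name": "cuLaunchKernel",
        #    "ts": 1234, "dur": 5, "pid": 0, "tid": 1, "args": {...}}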

        kernel_name = "hipModuleLaunchKernel" if torch.version.hip else "cuLaunchKernel"

        self.assertTrue(
            any(("name" in event and kernel_name == event["name"]) for event in events)
        )

    def _test_profiling_kernel_names(
        self, fn, args, kernel_name_str: str, check_fn: Optional[Callable] = None
    ):
        """
        We expect a record_function event to be emitted on the CPU side,
        surrounding the launch of each triton kernel. Compile ``fn``, run it
        under the profiler, and assert that an event whose name contains both
        "triton" and ``kernel_name_str`` shows up.
        """
        fn_opt = torch.compile(fn)

        # Warm up: make sure compilation happens here, outside the profiled region.
        for _ in range(2):
            fn_opt(*args)

        if check_fn is not None:
            check_fn()

        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU], record_shapes=True
        ) as prof:
            fn_opt(*args)

        # The kernel name is expected to match the kernel's name in debug files etc.
        # The naming scheme could change in the future, but it seems reasonable that
        # the name should always contain "triton" and the given kernel_name_str -
        # e.g. if the kernel contains a sin op, its name should contain "sin".
        # If this changes in the future, feel free to update the assertion here.
        # Debugging tip: you can add prof.export_chrome_trace("test.json") inline in
        # this test, and then view test.json in chrome://tracing to see the trace.
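        # prof.export_chrome_trace("test.json")  # uncomment to dump a trace for inspection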
        self.assertTrue(
            any(
                (
                    hasattr(event, "name")
                    and kernel_name_str in event.name
                    and "triton" in event.name
                )
                for event in prof.events()
            )
        )
        return prof.events()

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_pointwise(self):
        def fn(x, y):
            return (x + y).sin().cos()

        args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

        events = self._test_profiling_kernel_names(fn, args, "sin")
        event_found = False
        for event in events:
            if event.name == "triton_poi_fused_add_cos_sin_0":
                event_found = True
                # The two inputs, the output buffer, and the scalar numel
                # argument (recorded as an empty shape).
                self.assertEqual(event.input_shapes, [[4, 4], [4, 4], [4, 4], []])
        self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_template(self):
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                return x @ y

            args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

            def check_fn():
                # _test_profiling_kernel_names calls this before asserting that
                # an mm kernel shows up in the trace. Reason: test machines
                # sometimes don't have enough SMs, in which case autotuning is
                # skipped and no triton mm kernel is ever launched.
                if (
                    torch._dynamo.utils.counters["inductor"][
                        "select_algorithm_autotune"
                    ]
                    == 0
                ):
                    raise unittest.SkipTest(
                        "select_algorithm didn't run, so we probably won't get "
                        "profiling data. The GPU might not have enough SMs."
                    )

            events = self._test_profiling_kernel_names(fn, args, "mm", check_fn)

            event_found = False
            for event in events:
                if event.name == "triton_tem_fused_mm_0":
                    event_found = True
                    # The two mm inputs and the output buffer.
                    self.assertEqual(event.input_shapes, [[4, 4], [4, 4], [4, 4]])
            self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_foreach(self):
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                # Inductor compiles the foreach op into a single horizontally
                # fused kernel (hence the "_for_" in the expected name).
                return torch._foreach_add(x, y)

            x = [torch.rand((4, 4), device="cuda") for _ in range(3)]
            y = [torch.rand((4, 4), device="cuda") for _ in range(3)]

            args = (x, y)

            events = self._test_profiling_kernel_names(fn, args, "_for_")
            event_found = False
            for event in events:
                if event.name == "triton_for_fused_0":
                    event_found = True
                    # Three inputs from x, three from y, and three output buffers.
                    self.assertEqual(event.input_shapes, [[4, 4]] * 9)
            self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_triton_hooks(self):
        from triton.compiler import CompiledKernel

        hooks_called = {"enter": False, "exit": False}

        # Triton invokes these class-level hooks around every kernel launch.
        def launch_enter_hook(lazy_dict):
            hooks_called["enter"] = True

        def launch_exit_hook(lazy_dict):
            hooks_called["exit"] = True

        # Save and restore the original hooks so this global state doesn't
        # leak into other tests.
        prev_enter_hook = CompiledKernel.launch_enter_hook
        prev_exit_hook = CompiledKernel.launch_exit_hook
        self.addCleanup(setattr, CompiledKernel, "launch_enter_hook", prev_enter_hook)
        self.addCleanup(setattr, CompiledKernel, "launch_exit_hook", prev_exit_hook)
        CompiledKernel.launch_enter_hook = launch_enter_hook
        CompiledKernel.launch_exit_hook = launch_exit_hook

        def fn(x, y):
            return torch._foreach_add(x, y)

        x = [torch.rand((4, 4), device="cuda") for _ in range(3)]
        y = [torch.rand((4, 4), device="cuda") for _ in range(3)]

        args = (x, y)
        fn_opt = torch.compile(fn)
        fn_opt(*args)

        self.assertTrue(hooks_called["enter"])
        self.assertTrue(hooks_called["exit"])

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CUDA:
        run_tests()