# Owner(s): ["module: inductor"]
import json
import unittest
from typing import Callable, Optional

import torch
import torch._inductor.test_case
import torch._inductor.utils
from torch._inductor import config
from torch.profiler import ProfilerActivity
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton


HAS_TRITON = has_triton()


class DynamoProfilerTests(torch._inductor.test_case.TestCase):
    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_triton_launch(self):
        # Verify that we get some sort of CPU-side indication of triton kernel launches
        # in the profile traces. Currently, those appear as `cuLaunchKernel` (or
        # `hipModuleLaunchKernel` on ROCm). If this detail changes, the test can be
        # updated or removed.
        @torch.compile
        def fn(x, y):
            return (x + y).sin().cos()

        x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))

        with torch.profiler.profile() as prof:
            fn(x, y)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace_json = json.load(f)

        self.assertTrue("traceEvents" in trace_json)
        events = trace_json["traceEvents"]

        kernel_name = "hipModuleLaunchKernel" if torch.version.hip else "cuLaunchKernel"

        def nameMatchesLaunchKernel(event_name):
            return kernel_name in event_name

        self.assertTrue(
            any(
                ("name" in event and nameMatchesLaunchKernel(event["name"]))
                for event in events
            )
        )

    def _test_profiling_kernel_names(
        self, fn, args, kernel_name_str: str, check_fn: Optional[Callable] = None
    ):
        """
        We expect a record_function event to be added on the CPU side, surrounding
        the launch of each triton kernel.
        """
        fn_opt = torch.compile(fn)

        # Warm up first so compilation doesn't show up in the profiled run.
        for _ in range(2):
            fn_opt(*args)

        if check_fn is not None:
            check_fn()

        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU], record_shapes=True
        ) as prof:
            fn_opt(*args)

        # The name of the kernel is expected to match the name of the kernel in debug
        # files etc. The name could change in the future, but it seems reasonable that
        # the name should always contain "triton" and kernel_name_str - e.g. if the
        # kernel contains a sin op, it should probably contain "sin" in the name.
        # If this changes in the future, feel free to change the assertion here.
        # Debugging tips: you can add prof.export_chrome_trace("test.json") inline in
        # this test, and then view test.json in chrome://tracing to see the trace.
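        # If the assertion below fails, a quick way to see what was actually recorded
        # (a debugging sketch, not part of the test contract) is to print the events:
        #
        #   for event in prof.events():
        #       print(event.name, getattr(event, "input_shapes", None))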
        self.assertTrue(
            any(
                (
                    hasattr(event, "name")
                    and kernel_name_str in event.name
                    and "triton" in event.name
                )
                for event in prof.events()
            )
        )
        return prof.events()

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_pointwise(self):
        def fn(x, y):
            return (x + y).sin().cos()

        args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

        events = self._test_profiling_kernel_names(fn, args, "sin")
        event_found = False
        for event in events:
            if event.name == "triton_poi_fused_add_cos_sin_0":
                event_found = True
                self.assertTrue(event.input_shapes == [[4, 4], [4, 4], [4, 4], []])
        self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_template(self):
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                return x @ y

            args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

            def check_fn():
                # _test_profiling_kernel_names will call this before asserting that
                # mm is in the trace. Reason: sometimes tests run on machines without
                # enough SMs, in which case autotuning is skipped entirely.
                if (
                    torch._dynamo.utils.counters["inductor"][
                        "select_algorithm_autotune"
                    ]
                    == 0
                ):
                    raise unittest.SkipTest(
                        "select_algorithm didn't run, we probably won't get profiling data. GPU might not have enough SMs."
                    )

            events = self._test_profiling_kernel_names(fn, args, "mm", check_fn)

            event_found = False
            for event in events:
                if event.name == "triton_tem_fused_mm_0":
                    event_found = True
                    self.assertTrue(event.input_shapes == [[4, 4], [4, 4], [4, 4]])
            self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_foreach(self):
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                return torch._foreach_add(x, y)

            x = [torch.rand((4, 4), device="cuda") for _ in range(3)]
            y = [torch.rand((4, 4), device="cuda") for _ in range(3)]

            args = (x, y)

            events = self._test_profiling_kernel_names(fn, args, "_for_")
            event_found = False
            for event in events:
                if event.name == "triton_for_fused_0":
                    event_found = True
                    # 3 x-inputs + 3 y-inputs + 3 outputs, all shaped (4, 4).
                    self.assertTrue(event.input_shapes == [[4, 4]] * 9)
            self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_triton_hooks(self):
        from triton.compiler import CompiledKernel

        hooks_called = {"enter": False, "exit": False}

        def launch_enter_hook(lazy_dict):
            hooks_called["enter"] = True

        def launch_exit_hook(lazy_dict):
            hooks_called["exit"] = True

        # Save and restore the previous hooks so this test doesn't leak global state
        # into other tests.
        prev_enter_hook = CompiledKernel.launch_enter_hook
        prev_exit_hook = CompiledKernel.launch_exit_hook
        CompiledKernel.launch_enter_hook = launch_enter_hook
        CompiledKernel.launch_exit_hook = launch_exit_hook
        try:

            def fn(x, y):
                return torch._foreach_add(x, y)

            x = [torch.rand((4, 4), device="cuda") for _ in range(3)]
            y = [torch.rand((4, 4), device="cuda") for _ in range(3)]

            args = (x, y)
            fn_opt = torch.compile(fn)
            fn_opt(*args)
        finally:
            CompiledKernel.launch_enter_hook = prev_enter_hook
            CompiledKernel.launch_exit_hook = prev_exit_hook

        self.assertTrue(hooks_called["enter"])
        self.assertTrue(hooks_called["exit"])


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CUDA:
        run_tests()
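# Usage note: run_tests() is gated on HAS_CUDA, so on a CPU-only machine this file
# exits without running anything. On a CUDA + Triton machine, a single test can be
# run directly, e.g. (illustrative invocation; adjust the path to this file):
#   python test_profiler.py -k test_inductor_profiling_triton_launch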