# Owner(s): ["module: inductor"]
import contextlib
import os
import subprocess
import sys
from unittest.mock import patch

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import rand_strided
from torch._inductor import config
from torch._inductor.codecache import PyCodeCache
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class TestKernelBenchmark(TestCase):
    device_type = GPU_TYPE

    @classmethod
    def setUpClass(cls):
        cls.exit_stack = contextlib.ExitStack()
        cls.exit_stack.enter_context(patch.object(config, "benchmark_kernel", True))

    @classmethod
    def tearDownClass(cls):
        cls.exit_stack.close()

    def setUp(self):
        super().setUp()
        PyCodeCache.cache.clear()

    def get_compiled_module(self):
        compiled_module = None
        for v in PyCodeCache.cache.values():
            if hasattr(v, "benchmark_compiled_module"):
                self.assertTrue(
                    compiled_module is None, "Found multiple compiled modules"
                )
                compiled_module = v

        self.assertTrue(compiled_module is not None)
        return compiled_module

    def verify_compiled_kernels(self, GB_count=1):
        compiled_module = self.get_compiled_module()

        # now run the compiled module in a subprocess and check its output
        bench_out = subprocess.check_output(
            f"{sys.executable} {compiled_module.__file__} -kc".split(),
            stderr=subprocess.STDOUT,
        ).decode()

        # make sure we have the bandwidth information in the output
        FileCheck().check_count(
            "GB/s",
            GB_count,
            exactly=1,
        ).run(bench_out)

    def verify_remove_inductor_deps(self, compiled_module):
        try:
            subprocess.check_output(
                f"{sys.executable} {compiled_module.__file__}".split(),
                env={**os.environ.copy(), "TORCHINDUCTOR_DUMP_LAUNCH_PARAMS": "1"},
                stderr=subprocess.STDOUT,
            )
        except subprocess.CalledProcessError as e:
            print(
                "Failed when running triton code with TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1",
                e,
            )
            print(e.output.decode())
            raise e
        from torch.utils._get_clean_triton import get_clean_triton

        cleaned_triton = get_clean_triton(
            compiled_module.__file__, f"{compiled_module.__file__}.cleaned"
        )
        self.assertTrue("@triton_heuristics" not in cleaned_triton)
        self.assertTrue(".run(" not in cleaned_triton)
        try:
            subprocess.check_output(
                f"{sys.executable} {compiled_module.__file__}.cleaned".split(),
                stderr=subprocess.STDOUT,
            )
        except subprocess.CalledProcessError as e:
            print("Failed when running cleaned triton", e)
            print(e.output.decode())
            print(cleaned_triton)
            raise e
        return cleaned_triton

    def check_bandwidth(self, compiled_module, num_gb):
        # now run the compiled module in a subprocess and check its output
        bench_out = subprocess.check_output(
            f"{sys.executable} {compiled_module.__file__} -k".split(),
            stderr=subprocess.STDOUT,
        ).decode()

        # make sure we have the bandwidth information in the output
        FileCheck().check_count(
            f"{num_gb} GB ",
            1,
            exactly=1,
        ).run(bench_out)

    def test_pw_kernel_benchmark(self):
        @torch.compile
        def f(x):
            return torch.sin(x) + torch.cos(x)

        inp = torch.rand(2, 3).to(device=GPU_TYPE)
        out = f(inp)
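        # Compiling and running f above populates PyCodeCache; verify_compiled_kernels
        # below re-runs the cached module in a subprocess with "-kc" and checks that
        # the benchmark output reports a "GB/s" figure for the generated kernel.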
        self.verify_compiled_kernels()

    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    @fresh_inductor_cache()
    def test_matmul_triton_kernel_benchmark(self):
        M = 12544
        N = 256
        K = 64
        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(N, K, dtype=torch.float16, device=GPU_TYPE).t()

        @torch.compile
        def f(a, b):
            return torch.relu(a @ b)

        f(a, b)
        self.verify_compiled_kernels()

    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    @fresh_inductor_cache()
    def test_mm_triton_kernel_benchmark(self):
        M = 2048
        N = 2432
        K = 1949
        K_2 = 3581
        a = rand_strided((M, K_2), (K_2, 1), device="cuda", dtype=torch.float16)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float16)

        @torch.compile
        def f(a, b):
            a_1 = torch.narrow(a, 1, 0, K)
            c = torch.mm(a_1, b)
            return c

        f(a, b)
        self.verify_compiled_kernels(GB_count=3)

        # make sure we correctly generate the grid info
        compiled_module = self.get_compiled_module()
        with open(compiled_module.__file__) as f:
            source_code = f.read()
        lines = source_code.split("\n")
        meta = [l for l in lines if "meta0 = {" in l]
        scope = {}
        from torch._inductor.kernel.mm_common import mm_grid

        exec(meta[0], scope)
        grid = mm_grid(M, N, scope["meta0"])
        FileCheck().check_count(
            f"grid={grid}",
            2,
            exactly=1,
        ).run(source_code)

    def test_matmul_bandwidth_computation(self):
        """
        The test does a matmul and then a mul. Without max-autotune, we use
        the matmul from aten, so there is a single triton kernel for the mul.
        The generated kernel looks like:

            @triton.jit
            def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):

        Note the in_out_ptr0 argument. It's for a 1000x1000 tensor, but it's
        updated in-place, so when computing the bandwidth we should count
        the total memory access as 2 * 1000 * 1000 * 4 = 8MB. This amount is
        what this test asserts.
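        In GB that is 8e6 / 1e9 = 0.008, which is the value passed to
        check_bandwidth below.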
        """
        torch.set_float32_matmul_precision("high")  # suggested by a warning

        @torch.compile
        def f(x, y):
            z = x @ y
            w = z * z
            return w

        M, N, K = 1000, 1000, 10
        x = torch.rand(M, K).to(device=GPU_TYPE)
        y = torch.rand(K, N).to(device=GPU_TYPE)
        out = f(x, y)

        compiled_module = self.get_compiled_module()

        self.check_bandwidth(compiled_module, 0.008)

    def test_unused_input_bandwidth_computation(self):
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            return a + c

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(c, 0)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_c + size_out
        #        = (5 * 1000000 + 5 * 1000000 + 5 * 1000000) * 2 / 1e9
        #        = 0.030
        self.check_bandwidth(compiled_module, "0.030")

    def test_reduction_bandwidth_computation(self):
        @torch.compile
        def f(a):
            return torch.sum(a, dim=1)

        a = torch.rand(1000, 20, 1000, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a,)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_out
        #        = (1000 * 20 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.042
        self.check_bandwidth(compiled_module, "0.042")

    @config.patch(max_autotune=True)
    def test_fused_layernorm_bandwidth_computation(self):
        M, N = 10, 1000000

        @torch.compile
        def f(a, b, c, d):
            x0 = a + b
            x1 = torch.nn.functional.layer_norm(
                x0, normalized_shape=(N,), weight=c, bias=d, eps=1e-05
            )
            x2 = torch.sigmoid(x1)
            return x0 * x2

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        d = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a, b, c, d)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_b + size_c + size_d + size_out
        #        = (10 * 1000000 + 1000000 + 1000000 + 1000000 + 10 * 1000000) * 2 / 1e9
        #        = 0.046
        self.check_bandwidth(compiled_module, "0.046")

    def test_slice_add_cat_bandwidth_computation(self):
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            x0 = torch.narrow(b, 1, N, N)
            # broadcasting
            x1 = x0 + c
            return torch.cat([a, x1], dim=1)

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # we overestimate the size of "slice_b" due to torch.cat
        # num_gb = size_a + size_slice_b + size_c + size_out
        #        = (5 * 1000000 + 5 * 2000000 + 1000000 + 5 * 2000000) * 2 / 1e9
        #        = 0.052
        self.check_bandwidth(compiled_module, "0.052")

    def test_slice_add_bandwidth_computation(self):
        M, N = 5, 1000000

        @torch.compile
        def f(a, b, c):
            x0 = torch.narrow(b, 1, N, N)
            return a + x0 + c

        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # num_gb = size_a + size_slice_b + size_c + size_out
        #        = (5 * 1000000 + 5 * 1000000 + 1000000 + 5 * 1000000) * 2 / 1e9
        #        = 0.032
        self.check_bandwidth(compiled_module, "0.032")

    def test_mm_slice_add_bandwidth_computation(self):
        M, N, K = 1000, 1000, 30

        @torch.compile
        def f(a, b, c):
            x0 = torch.mm(a, b)
            x1 = torch.narrow(c, 1, 20 * N, N)
            x2 = torch.narrow(c, 1, 21 * N, N)
            return x0 + x1 + x2

        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # torch.mm becomes an extern kernel, so we measure the nbytes
        # for the pointwise add kernel:
        # num_gb = size_x0 + 2 * size_slice_c + size_out
        #        = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.008
        num_gb = "0.008"
        if GPU_TYPE == "xpu":
            # On the XPU backend, mm + add + add is fused into addmm + add,
            # while CUDA prefers not to fuse add into mm; see
            # `should_prefer_unfused_addmm` in torch/_inductor/fx_passes/post_grad.py.
            num_gb = "0.006"

        self.check_bandwidth(compiled_module, num_gb)

    def test_mm_slice_add_bandwidth_computation_2(self):
        M, N, K = 1000, 1000, 30

        @torch.compile
        def f(a, b, c):
            x0 = torch.mm(a, b)
            x1 = torch.narrow(c, 1, 20 * N, N)
            x2 = torch.narrow(c, 1, 20 * N, N)
            return x0 + x1 + x2

        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
        inputs = (a, b, c)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()
        # torch.mm becomes an extern kernel, so we measure the nbytes
        # for the pointwise add kernel:
        # num_gb = size_x0 + size_slice_c + size_out
        #        = (1000 * 1000 + 1000 * 1000 + 1000 * 1000) * 2 / 1e9
        #        = 0.006
        # note that we only count size_slice_c once because the two accesses
        # have the same index.
        self.check_bandwidth(compiled_module, "0.006")

    @expectedFailureXPU
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_slice_mm_bandwidth_computation(self):
        M, N, K = 1000, 2000, 3000

        @torch.compile
        def f(a, b):
            x = torch.narrow(a, 1, K, K)
            return torch.mm(x, b)

        a = torch.rand(M, 3 * K, dtype=torch.float16, device=GPU_TYPE)
        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(a, 0)
        inputs = (a, b)
        out = f(*inputs)

        compiled_module = self.get_compiled_module()

        # c[1000, 2000] = x[1000, 3000] @ b[3000, 2000]
        # num_gb = (1000 * 2000 + 1000 * 3000 + 3000 * 2000) * 2 / 1e9
        #        = 0.022
        self.check_bandwidth(compiled_module, "0.022")

    def test_star_dep(self):
        """
        Test the bandwidth estimation for StarDep
        """

        @torch.compile
        def f(a, b):
            a[b] = 3.0

        a = torch.rand(10000, 5000, device=GPU_TYPE)
        b = torch.randint(
            0, 10000, [20000], device=GPU_TYPE, dtype=torch.int32
        ).unsqueeze(1)
        f(a, b)
        compiled_module = self.get_compiled_module()
        # 20000 * 4 = 80KB for b
        # 10000 * 5000 * 4 = 200MB for a
        self.check_bandwidth(compiled_module, "0.200")

    def test_split_scan(self):
        @torch.compile
        def f(a):
            return a.cumsum(-1)

        a = torch.rand(10000, 5000, device=GPU_TYPE)
        f(a.reshape(-1))
        compiled_module = self.get_compiled_module()
        # 10000 * 5000 * 4 = 200MB for a
        # double that to account for the output as well
        self.check_bandwidth(compiled_module, "0.400")

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps(self):
        @torch.compile
        def f(a):
            return a.cos().sin()

        a = torch.randn(5, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps_multiple_kernels(self):
        @torch.compile
        def f(a):
            a = torch.mm(a, a)
            a = a.cos().sin()
            a = torch.mm(a, a)
            a = torch.softmax(a, dim=-1)
            return a

        a = torch.randn(5, 5, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_remove_inductor_deps_templates(self):
        @torch.compile
        def f(a):
            a = torch.mm(a, a)
            a = a.cos()
            a = torch.mm(a, a)
            a = a.sin()
            return a

        a = torch.randn(128, 128, device=GPU_TYPE)
        f(a)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)

    @config.patch("triton.unique_kernel_names", True)
    @config.patch(benchmark_kernel=False)
    @config.patch(compile_threads=1)
    def test_remove_inductor_deps_scalar(self):
        @torch.compile
        def f(a, b):
            return a + b

        a = torch.tensor(1.0, device=GPU_TYPE)
        b = torch.tensor(2.0, device=GPU_TYPE)
        f(a, b)
        compiled_module = self.get_compiled_module()
        self.verify_remove_inductor_deps(compiled_module)


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()