# Owner(s): ["module: inductor"]
import math
import os
import sys

import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

import contextlib
import unittest

from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
from torch._inductor import config
from torch._inductor.scheduler import Scheduler


class TestCase(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "benchmark_kernel": True,
                    "benchmark_fusion": True,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()


class BenchmarkFusionTestTemplate:
    def test_softmax(self):
        def f(x):
            return torch.nn.functional.softmax(x, dim=-1)

        self.common(f, (torch.rand(2, 8192),))

    @slowTest
    def test_resnet18(self):
        import torchvision

        model = torchvision.models.resnet18()
        model.eval()
        batch_size = 16
        inputs = (torch.randn((batch_size, 3, 224, 224)),)
        self.common(model, inputs, atol=1e-2, rtol=1e-2)

    def test_register_spills(self):
        """
        This test can potentially trigger register spills.
        """
        old_benchmark_fn = Scheduler.benchmark_fused_nodes

        def new_benchmark_fn(scheduler, nodes):
            """
            We override Scheduler.benchmark_fused_nodes to return a latency of
            1.0 when there are no register spills. Without this, we may not be
            able to test the code path that handles register spilling, since
            the related fusion may already have been skipped due to a longer
            latency before registers start spilling.
            """
            ms, path = old_benchmark_fn(scheduler, nodes)
            if not math.isinf(ms):
                ms = 1.0
            return ms, path

        # Disable dynamic_scale_rblock to make it easier to trigger register
        # spilling.
        with unittest.mock.patch.object(
            Scheduler, "benchmark_fused_nodes", new_benchmark_fn
        ), config.patch("dynamic_scale_rblock", False):
            S = 512

            def f(*inputs):
                inputs = list(inputs)
                outputs = []
                out = torch.zeros(S, device=self.device)
                for x in inputs:
                    x = x * 2
                    x = x + 1
                    x = x.sum(dim=-1)
                    outputs.append(x)
                    out = out + x
                return outputs, out

            N = int(os.environ.get("NINP", "30"))
            inputs = [torch.randn(S, 2560, device=self.device) for _ in range(N)]
            opt_f = torch.compile(f)
            opt_f(*inputs)

    def test_foreach_kernel(self):
        """
        Benchmark fusion should, for now, skip benchmarking kernels that involve
        foreach kernels. Without the skipping logic, `codegen_node_schedule` may fail.
111 """ 112 a = torch.randn(1024, 256, device=self.device) 113 b = torch.randn(1024, 512, device=self.device) 114 115 def f(a, b): 116 a, b = torch._foreach_abs([a, b]) 117 return a + 1, b + 2 118 119 self.common(f, (a, b)) 120 121 @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON") 122 def test_avoid_register_spilling(self): 123 if self.device != "cuda": 124 raise unittest.SkipTest("CUDA only") 125 126 from torch.nn.functional import gelu 127 128 def foo(m, inp): 129 curr = m(inp) 130 tmps = [] 131 for _ in range(4): 132 curr = gelu(curr) 133 for t in tmps: 134 curr = curr + t 135 tmps.append(curr) 136 137 return curr 138 139 m = torch.nn.Linear(2048, 2048, bias=True).half().cuda() 140 inp = torch.rand([2048, 2048]).half().cuda() 141 142 with torch.no_grad(): 143 foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) 144 145 _, out_code = run_and_get_code(foo_c, m, inp) 146 147 # occasionally, CI will make this one kernel. just skip in this case 148 if not out_code[0].count("def triton_") == 2: 149 return 150 151 # should be multiple triton invocations 152 FileCheck().check("async_compile.wait").check_count( 153 ".run", 2, exactly=True 154 ).run(out_code[0]) 155 156 with config.patch( 157 {"benchmark_fusion": False, "epilogue_fusion": False} 158 ), torch.no_grad(): 159 torch._dynamo.reset() 160 161 foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) 162 163 _, out_code2 = run_and_get_code(foo_c, m, inp) 164 165 for c in out_code[0], out_code2[0]: 166 FileCheck().check("async_compile.wait").check("DeviceGuard").check_count( 167 "empty_strided_cuda", 2, exactly=True 168 ).check("return").run(c) 169 170 def test_tield_kernel_fusion(self): 171 def f(x): 172 y = realize(x + x.t()) 173 return y + 1 174 175 x = torch.randn(1024, 1024, device=self.device) 176 self.common(f, (x,)) 177 178 179if HAS_CUDA and not TEST_WITH_ASAN: 180 181 class BenchmarkFusionCudaTest(TestCase): 182 common = check_model_cuda 183 device = "cuda" 184 185 copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda") 186 187 class BenchmarkMultiTemplateFusionCudaTest(InductorTestCase): 188 @classmethod 189 def setUpClass(cls): 190 super().setUpClass() 191 cls._stack = contextlib.ExitStack() 192 cls._stack.enter_context( 193 config.patch( 194 { 195 "benchmark_kernel": True, 196 "benchmark_fusion": True, 197 "benchmark_epilogue_fusion": True, 198 } 199 ) 200 ) 201 202 @classmethod 203 def tearDownClass(cls): 204 cls._stack.close() 205 super().tearDownClass() 206 207 def setUp(self): 208 super().setUp() 209 if not is_big_gpu(0): 210 return self.skipTest("Need a big GPU to run max_autotune=True") 211 212 def _equivalent_output_code_impl(self, size, first_dim=None, activation=True): 213 def foo(m, inp): 214 a = m(inp) 215 if activation: 216 return torch.nn.functional.relu(a) 217 return a 218 219 foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) 220 first_dim = first_dim if first_dim is not None else size 221 222 m = torch.nn.Linear(size, size, bias=True).half().cuda() 223 inp = torch.rand([first_dim, size]).half().cuda() 224 225 with torch.no_grad(): 226 res, code = run_and_get_code(foo_c, m, inp) 227 228 torch._dynamo.reset() 229 with unittest.mock.patch.object( 230 torch._inductor.config, "benchmark_epilogue_fusion", False 231 ): 232 foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) 233 with torch.no_grad(): 234 res2, code2 = run_and_get_code(foo_c, m, inp) 235 236 self.assertEqual(res, res2, atol=1e-4, rtol=1.1) 237 return code, code2 238 239 
        @fresh_inductor_cache()
        @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
        def test_equivalent_template_code(self):
            code, code2 = self._equivalent_output_code_impl(256)
            for out_code in [code, code2]:
                FileCheck().check("def call").check_count(
                    "empty_strided_cuda", 1, exactly=True
                ).check("triton_tem_fused_relu_0.run").check_count(
                    "del", 3, exactly=True
                ).check(
                    "return"
                ).run(
                    out_code[0]
                )

        @fresh_inductor_cache()
        @torch._inductor.config.patch(max_autotune_gemm_backends="ATEN")
        def test_equivalent_extern_code(self):
            torch._dynamo.reset()

            code, code2 = self._equivalent_output_code_impl(512, 1, False)

            for out_code in [code, code2]:
                FileCheck().check("def call").check_count(
                    "empty_strided_cuda", 1, exactly=True
                ).check("extern_kernels.").check_count("del", 3, exactly=True).check(
                    "return"
                ).run(
                    out_code[0]
                )

        def test_changed_layout(self):
            # cat/addmm planning will change the layout - make sure it is propagated
            def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
                return torch.cat(
                    [
                        torch.addmm(a, b, c),
                        torch.addmm(b, c, a),
                    ],
                    1,
                )

            args = [
                torch.randn(4, 4, device="cuda"),
                torch.randn(4, 4, device="cuda"),
                torch.randn(4, 4, device="cuda"),
            ]

            expected = fn(*args)
            actual = torch.compile(fn, mode="max-autotune")(*args)
            self.assertEqual(expected, actual)

            torch._dynamo.reset()


if HAS_CPU and not torch.backends.mps.is_available():

    class BenchmarkFusionCpuTest(TestCase):
        common = check_model
        device = "cpu"

    copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCpuTest, "cpu")

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests()