# Owner(s): ["module: inductor"]
import gc
import math
import sys
import unittest

import torch
import torch._dynamo.config as dynamo_config
import torch.backends.cuda
import torch.nn.functional as F
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.runtime.hints import DeviceProperties
from torch._inductor.utils import (
    run_and_get_code,
    run_and_get_graph_lowering,
    run_fw_bw_and_get_code,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    SM80OrLater,
)
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    freeze_rng_state,
    IS_FBCODE,
    skipIfRocm,
    TEST_WITH_ASAN,
)
from torch.testing._internal.inductor_utils import skipCUDAIf


try:
    try:
        import triton
        from triton import language as tl
    except ImportError:
        raise unittest.SkipTest("requires triton")  # noqa: B904

    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


TestCase = test_torchinductor.TestCase
ToTuple = test_torchinductor.ToTuple
check_model_cuda = test_torchinductor.check_model_cuda
aten = torch.ops.aten


class CudaReproTests(TestCase):
    device = "cuda"
    common = check_model_cuda

    def test_index_put_issue(self):
        def forward(
            self,
            arg76_1,
            expand_default,
            full_like_default,
            _to_copy_default_67,
            zeros,
        ):
            sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True)
            view_default_57 = torch.ops.aten.view.default(sum_sym_int_19, [512, 768])
            where_self = torch.ops.aten.where.self(
                expand_default, view_default_57, full_like_default
            )
            clone_default_12 = torch.ops.aten.clone.default(zeros)
            index_put__default = torch.ops.aten.index_put_.default(
                clone_default_12, [arg76_1], where_self, True
            )
            return (index_put__default,)

        inps = [
            (torch.Size([512]), torch.int64),
            (torch.Size([512, 768]), torch.bool),
            (torch.Size([512, 768]), torch.float16),
            (torch.Size([4, 512, 768]), torch.float16),
            (torch.Size([512, 768]), torch.float16),
        ]
        inps = [torch.zeros(())] + [
            torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps
        ]
        mod = make_fx(forward)(*inps)
        compiled = compile_fx_inner(mod, inps)
        compiled(inps)

    @skipIfRocm
    def test_input_channels_last(self):
        m = torch.nn.Sequential(
            torch.nn.Conv2d(3, 3, 1, 1),
            ToTuple(),
        ).cuda()
        inp = torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda()

        self.common(
            m,
            (inp,),
            check_lowp=False,
        )

        @torch._dynamo.optimize()
        def foo(m, inp):
            return m(inp)

        self.assertTrue(foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last))

    # https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527
    def test_unspec_inputs_interop(self):
        class Repro(torch.nn.Module):
            def forward(self, x, y):
                unsqueeze = torch.ops.aten.unsqueeze.default(x, 4)
                permute = torch.ops.aten.permute.default(unsqueeze, [0, 1, 2, 4, 3])
                add = torch.ops.aten.add.Tensor(y, 1)
                return [permute, add]

        inps = [
            rand_strided((12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, "cuda"),
            rand_strided((), (), torch.int64, "cpu"),
        ]
        mod = make_fx(Repro().to(device="cuda"))(*inps)
        compiled = compile_fx_inner(mod, inps)
        compiled(inps)

    @unittest.skipIf(
        IS_FBCODE, "RuntimeError: Triton Error [CUDA]: invalid device context"
    )
    def test_backward_context(self):
        def fn(x):
            return x * 3

        x = torch.randn(4, device="cuda", requires_grad=True)
        gO = torch.rand_like(x)
        opt_fn = torch.compile(fn)
        out = opt_fn(x)
        out.backward(gO)

    @config.patch(fallback_random=True)
    def test_dtype_factory_issue(self):
        def forward():
            randn = torch.ops.aten.randn.default(
                [12, 64, 1, 64],
                dtype=torch.float32,
                device=torch.device(type="cuda", index=0),
                pin_memory=False,
            )
            unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1)
            return (unsqueeze_default_2,)

        mod = make_fx(forward)()
        compiled = compile_fx_inner(mod, ())
        assert compiled([])[0].device.type == "cuda"

    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_no_device_idx_repro_cudagraphs(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self):
                full = torch.ops.aten.full.default(
                    [8, 512],
                    1,
                    dtype=torch.float32,
                    layout=torch.strided,
                    device=torch.device(type="cuda", index=0),
                    pin_memory=False,
                )
                full_1 = torch.ops.aten.full.default(
                    [8, 512],
                    0,
                    dtype=torch.int64,
                    layout=torch.strided,
                    device=torch.device(type="cuda", index=0),
                    pin_memory=False,
                )
                return (full_1, full)

        self.common(Repro(), ())

    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_expanded_inputs_cudagraphs(self):
        @torch._dynamo.optimize("inductor")
        def fn(x, y):
            return x + y

        inputs = (
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
        )
        self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(
        automatic_dynamic_shapes=True,
        assume_static_by_default=False,
    )
    def test_dynamic_to_static_cudagraphs(self):
        for b in [False, True]:
            with config.patch({"triton.cudagraph_trees": b}):

                @torch._dynamo.optimize("inductor")
                def fn(x, y):
                    r = x + y
                    return r, r.size(0)

                inputs = (
                    torch.randn((5, 5), device="cuda"),
                    torch.randn((5, 5), device="cuda"),
                )
                self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5)))

                inputs = (
                    torch.randn((6, 6), device="cuda"),
                    torch.randn((6, 6), device="cuda"),
                )
                self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6)))

    @config.patch({"emulate_precision_casts": True})
    def test_emulate_low_precision(self):
        def foo(x):
            return torch.nn.functional.gelu(x) * 10.0

        inp = torch.rand([32], device="cuda", requires_grad=True, dtype=torch.bfloat16)
        out, codes = run_fw_bw_and_get_code(lambda: torch.compile(foo)(inp))

        # fwd, backward
        for code in codes:
            f = FileCheck()
            # in eager, there are two down casts
            for _ in range(2):
                f.check(".to(tl.bfloat16)").check_next(".to(tl.float32)")
            f.run(code)

        self.assertEqual(foo(inp), out)

    # TODO: Abstract this out, test more extensively
    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_dynamic_shapes(self):
        torch._dynamo.reset()  # Needed since everywhere else uses "inductor"

        def f(x):
            return x.cos().view(x.shape).sin()

        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        f2 = torch._dynamo.optimize(cnts)(f)

        f2(torch.randn(32))

        inp = torch.randn(16)
        real_out = f(inp)
        compiled_out = f2(inp)

        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(real_out, compiled_out)
        torch._dynamo.reset()

    @config.patch({"triton.cudagraphs": True, "size_asserts": False})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_expanded_inputs_cudagraphs_no_size_asserts(self):
        @torch._dynamo.optimize("inductor")
        def fn(x, y):
            return x + y

        inputs = (
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
        )
        self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

    @config.patch({"triton.cudagraph_trees": False})
    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_inplace_updates_cudagraphs(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = torch.nn.Parameter(
                    torch.randn(10, 20, requires_grad=True)
                )

            def forward(self, x):
                x = torch.matmul(x, self.weight1)
                return x

        from copy import deepcopy

        model = Repro().cuda()
        model_ref = deepcopy(model)
        model_opt = torch._dynamo.optimize("inductor")(model)

        input = torch.randn(10, 10, device="cuda", requires_grad=True)

        for i in range(2):
            output_ref = model_ref(input)
            output_res = model_opt(input)
            output_ref.sum().backward()
            output_res.sum().backward()
            for p_ref, p_res in zip(model_ref.parameters(), model_opt.parameters()):
                self.assertEqual(p_ref.grad, p_res.grad)
            with torch.no_grad():
                for param in model_ref.parameters():
                    param.add_(1.0)
                for param in model_opt.parameters():
                    param.add_(1.0)

    # https://github.com/pytorch/torchdynamo/issues/1850
    def test_inductor_output_aliases_intermediate(self):
        def foo(x):
            out = x + x
            return out.t()

        foo_opt = torch._dynamo.optimize("inductor")(foo)

        inpt = torch.randn(10, 10, device="cuda", requires_grad=True)
        # TODO: this is broken, fix later
        # out = foo_opt(inpt)
        # out.add_(2)

        out_ref = foo(inpt)
        out_ref.add_(2)
        # self.assertEqual(out_ref, out)

    def test_accuracy_issue1(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_features=768, out_features=2, bias=True
                )

            def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
                linear = self.linear(x)
                split = linear.split(1, dim=-1)
                getitem = split[0]
                squeeze = getitem.squeeze(-1)
                clamp = start_positions.clamp(0, 128)
                cross_entropy = torch.nn.functional.cross_entropy(
                    squeeze, clamp, None, None, 128, None, "mean", 0.0
                )
                return cross_entropy

        mod = Repro().cuda()
        opt_mod = torch._dynamo.optimize("inductor")(mod)
        mod.eval()
        opt_mod.eval()

        args = [
            ((1,), (1,), torch.int64, "cuda", False),
            ((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True),
        ]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]
        with torch.cuda.amp.autocast(enabled=False):
            assert same_two_models(mod, opt_mod, args), "Dynamo failed"

    @config.patch(allow_buffer_reuse=False)
    def test_issue103461(self):
        def forward(add_1):
            var_mean = torch.ops.aten.var_mean.correction(
                add_1, [2], correction=0, keepdim=True
            )
            getitem_1 = var_mean[1]
            return getitem_1

        x = torch.randn(1, 8, 768, device="cuda")
        correct = forward(x)
        actual = torch.compile(forward, fullgraph=True)(x)
        self.assertEqual(actual, correct)

    def test_full_copy(self):
        def forward(x):
            full_10 = torch.ops.aten.full.default(
                [204, 204, 28],
                0,
                dtype=torch.float64,
                layout=torch.strided,
                device="cuda",
                pin_memory=False,
            )
            return x + full_10.to("cpu")

        o = torch.randn([204, 204, 28], dtype=torch.float64)
        correct = forward(o)
        actual = torch.compile(forward, fullgraph=True)(o)
        self.assertEqual(actual, correct)

    def test_autotune_inplace_kernel(self):
        """
        This UT tests autotune on an inplace kernel. The autotune should not contaminate
        the input buffers when tuning with multiple configs. For more details, refer to
        https://github.com/openai/triton/issues/781
        https://github.com/pytorch/torchdynamo/issues/1670
        """
        from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
        from torch._inductor.runtime.hints import HeuristicType, instance_descriptor
        from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid

        def autotune(configs, meta):
            def decorator(fn):
                return CachingAutotuner(
                    # force autotune by setting save_cache_hook to False
                    fn,
                    triton_meta=meta,
                    configs=configs,
                    save_cache_hook=False,
                    mutated_arg_names=["in_out_ptr0"],
                    heuristic_type=HeuristicType.POINTWISE,
                )

            return decorator

        @autotune(
            configs=[
                triton.Config({"XBLOCK": 1}),
                triton.Config({"XBLOCK": 2}),
            ],
            meta={
                "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
                "device": DeviceProperties.create(torch.device("cuda")),
                "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
                "constants": {},
            },
        )
        @triton.jit
        def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):
            pid = tl.program_id(0)
            block_start = pid * XBLOCK
            offsets = block_start + tl.arange(0, XBLOCK)
            mask = offsets < xnumel
            x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0)
            y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0)
            output = x + y
            tl.store(in_out_ptr0 + offsets, output, mask=mask)

        xnumel = 384
        in0 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
        inout1 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
        inout2 = inout1.clone()

        stream0 = get_cuda_stream(0)
        kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)
        kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)

        assert same(
            inout1, inout2, tol=0.001, equal_nan=True
        ), "failed autotune with inplace kernel"

    def test_sort_stride_issue(self):
        # This minified testcase comes from detectron2_maskrcnn_r_50_fpn
        # There was a false error from our size_assert code
        @torch._dynamo.optimize(nopython=True)
        def forward(pred_objectness_logits_3_: torch.Tensor):
            sort_3 = pred_objectness_logits_3_.sort(descending=True, dim=1)
            getitem_12 = sort_3[0]
            return getitem_12

        args = [((1, 100), (0, 1), torch.float16, "cuda", False)]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]
        result = forward(*args)
        assert same(result, torch.sort(args[0], descending=True, dim=1)[0])

    def test_scalar_triton_index(self):
        # The indirect indexing via a scalar like below used to lead to
        # bad triton code that made triton segfault when compiling.
        # See https://github.com/pytorch/torchdynamo/issues/1515
        def fn(a):
            zero = torch.zeros((16,), device=a.device, dtype=torch.int64)
            return (a[zero],)

        a = torch.randn((8,), dtype=torch.float32, device="cuda")

        fn_optimized = torch._dynamo.optimize("inductor")(fn)
        assert same(fn(a), fn_optimized(a))

    def test_indirect_indexing_dense_mask(self):
        def fn(x, y):
            ne = torch.ops.aten.ne.Scalar(x, 1)
            sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1])
            sub = torch.ops.aten.sub.Tensor(sum_1, 1)
            unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1)
            gather = torch.ops.aten.gather.default(x, 1, unsqueeze)
            squeeze = torch.ops.aten.squeeze.default(gather)
            out = torch.ops.aten.multiply(y, squeeze)
            return (out,)

        a = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
        b = torch.zeros((1, 128), dtype=torch.int64, device="cuda")

        fn_optimized = torch._dynamo.optimize("inductor")(fn)
        assert same(fn(a, b), fn_optimized(a, b))

    def test_simplify_dims(self):
        def fn(a):
            return (a + 1,)

        self.common(fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],))

    @config.patch(permute_fusion=True)
    def test_permute_fusion(self):
        class Repro(torch.nn.Module):
            def forward(self, view, reshape_2):
                permute = view.permute(0, 2, 1)
                view = None
                reshape = torch.reshape(permute, (-1, 642))
                bmm = torch.bmm(permute, reshape_2)
                return (bmm,)

        args = [
            ((1024, 642, 160), (102720, 160, 1), torch.float32, "cuda", True),
            ((1024, 642, 20), (12840, 20, 1), torch.float32, "cuda", True),
        ]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]

        mod = Repro()
        opt_mod = torch._dynamo.optimize("inductor")(mod)

        ref = mod(*args)
        res = opt_mod(*args)
        self.assertTrue(same(ref, res))

    @config.patch({"triton.autotune_pointwise": True})
    def test_inplace_add_alpha_autotune(self):
        def fn(x, y):
            aten.add_.Tensor(x, y, alpha=0.55)
            return (x,)

        x1 = torch.zeros(2, 3, 4, 10, device="cuda")
        x2 = torch.zeros(2, 3, 4, 10, device="cuda")
        x3 = torch.zeros(2, 3, 4, 10, device="cuda")
        y = torch.randn(2, 3, 4, 10, device="cuda").to(
            memory_format=torch.channels_last
        )
        fn_fx = make_fx(fn)(x1, y)
        fn_compiled = compile_fx_inner(fn_fx, [x1, y])
        fn(x2, y)
        fn_compiled([x3, y])
        assert same(x2, x3)

    @config.patch({"triton.autotune_pointwise": True})
    def test_inplace_buffer_autotune(self):
        def foo(x, y, z):
            a = x @ y
            return a.unsqueeze(0).unsqueeze(0) + z

        x = torch.zeros(5, 5, device="cuda")
        y = torch.zeros(5, 5, device="cuda")
        z = torch.zeros(1, 1, 5, 5, device="cuda").to(memory_format=torch.channels_last)
        self.common(
            foo,
            (x, y, z),
            check_lowp=False,
        )

    def test_memory_history_inductor(self):
        def called_inside_compile(x, w, b):
            a = x @ w + b
            return torch.sigmoid(a)

        @torch.compile
        def fn(x, w, b):
            x = called_inside_compile(x, w, b)
            return called_inside_compile(x, w, b)

        w = torch.rand(3, 3, device="cuda")
        b = torch.rand(3, device="cuda")
        x = torch.rand(3, device="cuda")
        try:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history(True)
            r = fn(x, w, b)
        finally:
            torch.cuda.memory._record_memory_history(False)
        snapshot = str(torch.cuda.memory._snapshot())
        self.assertTrue("called_inside_compile" in snapshot)

    def test_negative_arange_dynamic_shapes(self):
        # Repro from alibi relative encodings
        def sign(x):
            return (x > 0) - (x < 0)

        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                nheads = 16
                start = math.log2(0.5)
                end = math.log2(1 / (2**8))

                self.scales = nn.Buffer(
                    2
                    ** torch.arange(
                        start,
                        end + 1e-6 * sign(end - start),
                        (end - start) / (nheads - 1),
                    ).view(1, nheads, 1, 1),
                )
                self.emb = nn.Embedding(1024, 256)
                self.dec_layer = nn.TransformerDecoderLayer(
                    256, 16, 512, batch_first=True, norm_first=True
                )
                self.head = nn.Linear(256, 1024)

            def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor):
                padmask = dec_in == 0
                dec_mask = padmask.unsqueeze(-1) == padmask.unsqueeze(-2)
                dec_mask = dec_mask.to(dtype=torch.float32)
                dec_mask = dec_mask.tril(diagonal=0).cuda()

                q_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
                k_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
                rel_pos = k_pos[None, :] - q_pos[:, None]
                values = rel_pos.abs().neg().unsqueeze(0).unsqueeze(0)
                dec_bias = values * self.scales
                dec_bias.tril_(diagonal=0)

                dec_mask = dec_mask + dec_bias[0]
                out = self.emb(dec_in)
                out = self.dec_layer(out, enc_out, tgt_mask=dec_mask)
                return self.head(out)

        mod = Repro().cuda()
        opt_mod = torch._dynamo.optimize("inductor", dynamic=True)(mod)
        mod.eval()
        opt_mod.eval()

        enc_out = torch.rand(1, 512, 256).cuda()
        dec_inputs = [
            torch.randint(0, 512, (1, i + 1), dtype=torch.long).cuda() for i in range(8)
        ]

        for dec_inp in dec_inputs:
            assert same_two_models(
                mod, opt_mod, [enc_out, dec_inp], only_fwd=True
            ), "Inductor with dynamic shapes failed"

    def test_issue97695_1input(self):
        def fn(arg3_1, relu, permute_1):
            addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
            cat_2 = torch.ops.aten.cat.default([addmm_1], 1)
            return (cat_2,)

        args = [
            ((96,), (1,), torch.float32, "cuda"),
            ((10, 256), (256, 1), torch.float32, "cuda"),
            ((256, 96), (1, 256), torch.float32, "cuda"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
        correct = fn(*args)

        mod = make_fx(fn, tracing_mode="real")(*args)
        compiled = compile_fx_inner(mod, args)
        ref = compiled(list(args))
        assert same(ref, correct)

        ref = torch.compile(fn, fullgraph=True)(*args)
        assert same(ref, correct)

    def test_issue_103924(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.temperature = 1
                self.layer = torch.nn.Softmax(dim=1)

            def forward(self, x):
                n_samples, _ = x.shape
                y = 1.0 * torch.ones(n_samples, dtype=x.dtype, device=x.device)
                inp = x / y[..., None]
                return self.layer(inp)

        x = torch.rand([4, 4], device="cuda")
        m = MyModule()
        opt_m = torch.compile(backend="inductor")(m)
        self.assertEqual(opt_m(x), m(x))

    def test_issue97695_2input(self):
        def fn(arg3_1, arg3_2, relu, permute_1):
            addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
            addmm_2 = torch.ops.aten.addmm.default(arg3_2, relu, permute_1)
            cat_2 = torch.ops.aten.cat.default([addmm_1, addmm_2], 1)
            return (cat_2,)

        args = [
            ((96,), (1,), torch.float32, "cuda"),
            ((96,), (1,), torch.float32, "cuda"),
            ((10, 256), (256, 1), torch.float32, "cuda"),
            ((256, 96), (1, 256), torch.float32, "cuda"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
        correct = fn(*args)

        ref = torch.compile(fn, fullgraph=True)(*args)
        assert same(ref, correct)

    def test_scatter_index_not_wrapped(self):
        src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], device=self.device)
        index = torch.tensor([0, 1, 0, 1, 2, 0], device=self.device)
        input = torch.tensor([1.0, 2.0, 3.0, 4.0], device=self.device)
        compiled_sr = torch.compile(torch.scatter_reduce)

        input_orig = input.clone()
        out, code = run_and_get_code(compiled_sr, input, 0, index, src, "sum")
        # tmp0 - not wrapping of negative numbers
        FileCheck().check("tl.device_assert(((0 <= tmp0) & (tmp0 < 4))").check_next(
            "atomic_add"
        ).run(code[0])
        self.assertEqual(
            out, torch.scatter_reduce(input_orig.clone(), 0, index, src, "sum")
        )

    def test_embedding_var_mean(self):
        def forward(arg0_1):
            full = torch.ops.aten.full.default(
                [1, 2048],
                1,
                dtype=torch.float32,
                layout=torch.strided,
                device=torch.device(type="cuda", index=0),
                pin_memory=False,
            )
            convert_element_type_1 = torch.ops.prims.convert_element_type.default(
                full, torch.int64
            )
            cumsum = torch.ops.aten.cumsum.default(convert_element_type_1, 1)
            mul = torch.ops.aten.mul.Tensor(cumsum, convert_element_type_1)
            sub_1 = torch.ops.aten.sub.Tensor(mul, 1)
            slice_5 = torch.ops.aten.slice.Tensor(sub_1, 0, 0, 9223372036854775807)
            slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807)
            add_2 = torch.ops.aten.add.Tensor(slice_6, 2)
            embedding_1 = torch.ops.aten.embedding.default(arg0_1, add_2)
            var_mean = torch.ops.aten.var_mean.correction(
                embedding_1, [2], correction=0, keepdim=True
            )
            return [var_mean[0], var_mean[1], add_2]

        emb = torch.randn([2050, 768], device="cuda")
        gm = make_fx(forward)(emb)
        opt = torch._inductor.compile_fx.compile_fx_inner(gm, [emb])
        opt([emb])
        torch.cuda.synchronize()

    def test_deterministic_algorithms(self):
        N = 10000

        @torch.compile
        def fn(idx, values):
            x = torch.zeros(1, device="cuda")
            x[idx] += values
            return x

        idx = torch.zeros(N, dtype=torch.int64, device="cuda")
        values = torch.randn(N, device="cuda")

        r0 = fn(idx, values)
        with DeterministicGuard(True):
            r1 = fn(idx, values)
            for _ in range(10):
                rn = fn(idx, values)
                self.assertEqual(r1, rn, atol=0, rtol=0)

    # https://github.com/pytorch/pytorch/issues/96406
    def test_linear_cpu_input(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 4)

            def forward(self, data):
                data = data.to("cuda")
                return self.linear(data)

        mod = Model().cuda().eval()
        with torch.no_grad():
            self.common(mod, (torch.randn(4, 4),))

    @config.patch({"fallback_random": True, "triton.cudagraphs": True})
    def test_xlnet_lm_stride_repro(self):
        class Repro(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout(p=0.1, inplace=False)

            def forward(self, x):
                y = torch._C._nn.gelu(x)
                return self.dropout(y)

        mod = Repro()
        x = torch.randn((512, 1, 4096), requires_grad=True, device="cuda")
        y = torch.compile(mod)(x)
        # Inductor claims the output layout of gelu's saved variable for
        # backwards will be (4096, 4096, 1) but in actuality it is (4096,
        # 2097152, 1). Fortunately this doesn't actually matter in practice.
        y.sum().backward()

    def test_lookup_seed_backward(self):
        @torch.compile(fullgraph=True)
        def forward(inductor_seeds, mul_4, view_15):
            inductor_lookup_seed_2 = torch.ops.prims.inductor_lookup_seed.default(
                inductor_seeds, 2
            )
            inductor_random_2 = torch.ops.prims.inductor_random.default(
                [2, 512, 768], inductor_lookup_seed_2, "rand"
            )
            gt_2 = torch.ops.aten.gt.Scalar(inductor_random_2, 0.1)
            mul_7 = torch.ops.aten.mul.Tensor(gt_2, view_15)
            mul_8 = torch.ops.aten.mul.Tensor(mul_7, 1.1111111111111112)
            add_5 = torch.ops.aten.add.Tensor(mul_8, mul_4)
            var_mean_1 = torch.ops.aten.var_mean.correction(
                add_5, [2], correction=0, keepdim=True
            )
            getitem_3 = var_mean_1[1]
            sub_3 = torch.ops.aten.sub.Tensor(add_5, getitem_3)
            return (sub_3,)

        buf0 = torch.zeros((37,), dtype=torch.int64, device="cuda")
        buf1 = torch.zeros((2, 512, 768), device="cuda")
        buf2 = torch.zeros((2, 512, 768), device="cuda")
        forward(buf0, buf1, buf2)

    def test_issue100806(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 20)
                self.linear2 = torch.nn.Linear(20, 30)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = torch.cat((x, x), dim=1)
                x = x.view(-1, 2, 30)
                x = x[:, 1, :]
                x = self.relu(x)
                return x

        device = "cuda"
        batch_size = 2
        x = torch.randn(batch_size, 10).to(device)
        func = Model().to(device)

        with torch.no_grad():
            func.train(False)
            jit_func = torch.compile(func)

            res1 = func(x)
            res2 = jit_func(x)
            self.assertEqual(res1, res2)

    def test_issue103481(self):
        def fn(x, y):
            # NOTE: 6 dimensions is important! does not fail for 5 dimensions
            mean = torch.mean(x, [2, 3, 4, 5], keepdim=True)
            add = mean + y
            return add

        x = torch.rand(4, 4, 4, 4, 4, 4, device="cuda")
        y = torch.rand((), device="cuda")
        expect = fn(x, y)

        opt_fn = torch.compile(fn)
        actual = opt_fn(x, y)

        self.assertEqual(expect, actual)

    @config.patch({"triton.dense_indexing": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_bucketize_dynamic_dense(self):
        """
        Make sure that ops.bucketize() can handle dense_indexing, which previously
        caused issues due to incorrect handling of the size of offsets.
        """

        def fn(values, offsets):
            return torch.bucketize(values, offsets)

        values = torch.rand((64, 64), device="cuda")
        offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device="cuda")

        expect = fn(values, offsets)

        opt_fn = torch.compile(fn, dynamic=True)
        actual = opt_fn(values, offsets)

        self.assertEqual(expect, actual)

    def test_float64_constants(self):
        def fn():
            # NOTE: tensors of all the same value are constant folded, so we
            # need a tensor with two distinct values
            a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device="cuda")
            return a * 2e50

        cfn = torch.compile(fn)
        expect = fn()
        actual = cfn()
        self.assertEqual(expect, actual, atol=0, rtol=0)

    def test_issue104759(self):
        def fn(arg7_1, add_1, permute_2, select_scatter, slice_8):
            slice_scatter_4 = torch.ops.aten.slice_scatter.default(
                permute_2, select_scatter, 0, 1, 9223372036854775807
            )
            permute_3 = torch.ops.aten.permute.default(slice_scatter_4, [1, 3, 0, 2, 4])
            view_6 = torch.ops.aten.view.default(permute_3, [1, 1000, 48])
            view_7 = torch.ops.aten.view.default(view_6, [1000, 48])
            view_8 = torch.ops.aten.view.default(view_7, [1, 1000, 48])
            view_9 = torch.ops.aten.view.default(view_8, [1, 1000, 3, 4, 4])
            permute_4 = torch.ops.aten.permute.default(view_9, [2, 0, 3, 1, 4])
            slice_7 = torch.ops.aten.slice.Tensor(permute_4, 0, 1, 9223372036854775807)
            slice_scatter_5 = torch.ops.aten.slice_scatter.default(
                slice_8, slice_7, 4, 0, 9223372036854775807
            )
            slice_scatter_6 = torch.ops.aten.slice_scatter.default(
                arg7_1, slice_scatter_5, 3, 0, 1000
            )
            mul_8 = torch.ops.aten.mul.Scalar(add_1, 0.7071067811865476)
            slice_9 = torch.ops.aten.slice.Tensor(slice_scatter_6, 3, 0, 1000)
            slice_10 = torch.ops.aten.slice.Tensor(slice_9, 4, 0, 9223372036854775807)
            select_2 = torch.ops.aten.select.int(slice_10, 0, 0)
            permute_5 = torch.ops.aten.permute.default(select_2, [0, 1, 3, 2])
            mul_9 = torch.ops.aten.mul.Scalar(permute_5, 0.7071067811865476)
            expand = torch.ops.aten.expand.default(mul_8, [1, 4, 1000, 4])
            view_10 = torch.ops.aten.view.default(expand, [4, 1000, 4])
            expand_1 = torch.ops.aten.expand.default(mul_9, [1, 4, 4, 1000])
            view_11 = torch.ops.aten.view.default(expand_1, [4, 4, 1000])
            bmm = torch.ops.aten.bmm.default(view_10, view_11)
            return (bmm,)

        args = []
        args.append(torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device="cuda"))
        args.append(
            rand_strided(
                (1, 4, 1000, 4), (16000, 4, 16, 1), dtype=torch.float16, device="cuda"
            )
        )
        args.append(
            rand_strided(
                (3, 1, 4, 1000, 4),
                (16, 48000, 4, 48, 1),
                dtype=torch.float16,
                device="cuda",
            )
        )
        args.append(
            rand_strided(
                (2, 1, 4, 1000, 4),
                (16, 48000, 4, 48, 1),
                dtype=torch.float16,
                device="cuda",
            )
        )
        args.append(
            rand_strided(
                (2, 1, 4, 1000, 4),
                (19200, 19200, 4800, 4, 1),
                dtype=torch.float16,
                device="cuda",
            )
        )

        correct = fn(*args)
        mod = make_fx(fn, tracing_mode="real")(*args)
        compiled = compile_fx_inner(mod, args)
        ref = compiled(list(args))
        assert same(ref, correct)

    @config.patch({"triton.cudagraphs": True})
    def test_index_put_inplace_cudagraph(self):
        def fn(x, y, z):
            x = torch.zeros_like(x)
            return x.index_put_([y], z, True)

        x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
        y = torch.zeros((512,), device="cuda", dtype=torch.int64)
        z = torch.ones((512, 512), device="cuda", dtype=torch.bool)

        opt_fn = torch._dynamo.optimize("inductor")(fn)

        ref = fn(x, y, z)

        # run it twice to test cuda graph issue
        res = opt_fn(x, y, z)
        res = opt_fn(x, y, z)

        self.assertEqual(ref, res)

    @config.patch({"triton.cudagraphs": True})
    @config.patch({"fx_graph_cache": True})
    def test_index_put_cudagraph(self):
        for _ in range(2):

            def fn(x, y, z):
                x = torch.zeros_like(x)
                return x.index_put([y], z, True)

            x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
            y = torch.zeros((512,), device="cuda", dtype=torch.int64)
            z = torch.ones((512, 512), device="cuda", dtype=torch.bool)

            opt_fn = torch._dynamo.optimize("inductor")(fn)

            ref = fn(x, y, z)

            # run it twice to test cuda graph issue
            res = opt_fn(x, y, z)
            res = opt_fn(x, y, z)

            self.assertEqual(ref, res)
            torch._dynamo.reset()
            gc.collect()

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
    )
    def test_flash_attention_dynamic(self):
        class Model(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

                self.q = nn.Linear(1024, 1024)
                self.k = nn.Linear(1024, 1024)
                self.v = nn.Linear(1024, 1024)

            def forward(self, x):
                batch_size, seq_len, _ = x.size()

                queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
                keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
                values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)

                attn = F.scaled_dot_product_attention(
                    queries,
                    keys,
                    values,
                )

                return attn

        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        model = Model().cuda().half()
        model = torch.compile(model, backend=cnts, dynamic=True)

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=False,
            enable_mem_efficient=False,
            enable_cudnn=False,
        ):
            input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
            input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
            input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)

            out1 = model(input1)
            out2 = model(input2)
            out3 = model(input3)

            self.assertEqual(cnts.frame_count, 1)

    @config.patch({"triton.cudagraphs": True})
    def test_index_put_no_fallback_cudagraph(self):
        def fn(x, y, z):
            x = torch.zeros_like(x)
            return x.index_put([y], z, True)

        x = torch.zeros((512, 512), device="cuda", dtype=torch.int32)
        y = torch.zeros((512,), device="cuda", dtype=torch.int64)
        z = torch.ones((512, 512), device="cuda", dtype=torch.int32)

        opt_fn = torch._dynamo.optimize("inductor")(fn)

        ref = fn(x, y, z)

        # run it twice to test cuda graph issue
        res = opt_fn(x, y, z)
        res = opt_fn(x, y, z)

        self.assertEqual(ref, res)

    # https://github.com/pytorch/pytorch/issues/104937
    def test_linear_with_zero_infeature_size(self):
        m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")
        x = torch.rand(1, 1, 0, device="cuda")
        expect = m(x)
        opt_fn = torch.compile(m)
        actual = opt_fn(x)
        self.assertEqual(expect, actual)

    @config.patch(fallback_random=True)
    def test_multi_output_layout_fallback(self):
        mod = nn.RReLU(lower=3.2350976, upper=8.4220314, inplace=True)
        inp = torch.rand([4, 4]).cuda()
        m = torch.compile(mod)

        with freeze_rng_state():
            o1 = m(inp.clone())

        o2 = mod(inp.clone())

        self.assertEqual(o1, o2)

    def test_cat_int8_one_kernel(self):
        @torch.compile()
        def cat(inps):
            return torch.cat(inps) + 1

        for dtype in [torch.uint8, torch.int8]:
            inps = [
                torch.empty([256, 256], dtype=dtype, device="cuda") for _ in range(4)
            ]

            out, code = run_and_get_code(cat, inps)
            self.assertEqual(torch.cat(inps) + 1, out)
            FileCheck().check_not("aten.cat.default(").check_count(
                ".run(", 1, exactly=True
            ).run(code[0])

    @config.patch("triton.use_block_ptr", True)
    def test_selecsls42b_misaligned_address(self):
        # https://github.com/openai/triton/issues/2836

        @torch.compile(fullgraph=True)
        def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
            div = torch.ops.aten.div.Scalar(expand, 16)
            where = torch.ops.aten.where.self(arg207_1, full, div)
            convert_element_type_43 = torch.ops.prims.convert_element_type.default(
                where, torch.float32
            )
            sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
            sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
            mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
            sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
            mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
            unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
            unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
            unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
            mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
            mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
            unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
            unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
            unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
            mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
            sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
            sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
            return (sub_2,)

        args = [
            torch.randn((8, 1024, 4, 4), device="cuda") > 0,  # torch.bool tensor
            torch.randn((1, 1024, 1, 1), device="cuda"),
            torch.randn((8, 1024, 4, 4), device="cuda"),
            torch.randn((8, 1024, 1, 1), dtype=torch.float16, device="cuda").expand(
                (8, 1024, 4, 4)
            ),
            torch.randn((), device="cuda"),
            torch.randn((1024,), device="cuda"),
        ]
        fn(*args)
        torch.cuda.synchronize()  # shake out Triton Error [CUDA]: misaligned address

    @skipIfRocm
    def test_non_commutative_scan_op(self):
        from torch._higher_order_ops.associative_scan import associative_scan

        a = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")
        b = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")

        def baseline(v, u):
            A = []
            A.append(b[:, 0])
            for i in range(1, v.shape[1]):
                A.append(a[:, i] * A[i - 1] + b[:, i])
            return torch.stack(A, dim=1)

        def combine_fn(i, j):
            ia, ib = i
            ja, jb = j
            return ia * ja, ib * ja + jb

        @torch.compile
        def compiled_scan(a, b):
            return associative_scan(combine_fn, (a, b), dim=-1)[1]

        out1 = baseline(a, b)
        out2 = compiled_scan(a, b)
        self.assertEqual(out1, out2)

    def test_dynamic_persistent_reductions(self):
        @torch.compile(dynamic=True)
        def inner_reduce(x):
            assert x.shape[1] <= 1024
            return x.sum(1)

        a = torch.randn(50, 600, device="cuda")
        out, code = run_and_get_code(inner_reduce, a)
        self.assertEqual(inner_reduce(a), out)
        self.assertTrue("for roffset" not in code)

        @torch.compile(dynamic=True)
        def outer_reduce(x):
            assert x.shape[0] <= 64
            return x.sum(0)

        out, code = run_and_get_code(outer_reduce, a)
        self.assertEqual(outer_reduce(a), out)
        self.assertTrue("for roffset" not in code)

    def test_non_contiguous_unaligned_input_indices(self):
        from torch._inductor.compile_fx import remove_unaligned_input_idxs

        inputs = [torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")[1:]]
        idxs = remove_unaligned_input_idxs(inputs, [1])
        self.assertEqual(idxs, [])

        inputs = [
            torch.ones(2, 2, device="cuda"),
            torch.ones(2, 2, device="cuda"),
            torch.ones(2, 2, device="cuda")[1:],
        ]
        idxs = remove_unaligned_input_idxs(inputs, [0, 2])
        self.assertEqual(idxs, [0])

    @config.patch("triton.cudagraphs", True)
    def test_unused_cpu_input_cudagraphs(self):
        def fn(x, y):
            return x.sin().sin().sin().sin().cos() + 1

        fx_graph = torch.fx.symbolic_trace(fn)
        inp = [torch.randn(64, device="cuda"), torch.randn(64, device="cpu")]
        compiled_fn, (graph,) = run_and_get_graph_lowering(
            torch._inductor.compile, fx_graph, inp
        )
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})
        self.assertEqual(compiled_fn(*inp), fn(*inp))

    def test_epilogue_fusion_with_view(self):
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
                self.linear = torch.nn.Linear(262144, 100)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = x.view(x.size(0), -1)
                return self.relu(self.linear(x))

        m = ToyModel().to(device="cuda:0")
        input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
        from torch._inductor.utils import fresh_inductor_cache

        with fresh_inductor_cache():
            cm = torch.compile(m, mode="max-autotune")
            out = cm(input_tensor)
            out2 = m(input_tensor)
            self.assertEqual(out, out2, atol=1e-3, rtol=1e-3)

    @config.patch("triton.cudagraphs", True)
    def test_cpu_index(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            return x[torch.arange(32)]

        result, (graph,) = run_and_get_graph_lowering(
            fn, torch.randn(64, device="cuda")
        )
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})

        inp = torch.randn(64, device="cuda", requires_grad=True)
        result, (graph,) = run_and_get_graph_lowering(fn, inp)
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})

        result, (graph,) = run_and_get_graph_lowering(lambda: result.sum().backward())
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})

    def test_reflection_pad_loop_order(self):
        def fn(x, y):
            a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect")
            b = torch.nn.functional.pad(y, (5, 5, 5, 5), mode="reflect")
            return a + b

        cfn = torch.compile(fn)
        a = torch.rand((10, 10, 10), device="cuda")
        b = torch.rand((10, 10, 10), device="cuda")
        expect = fn(a, b)
        actual, code = run_and_get_code(cfn, a, b)
        self.assertEqual(expect, actual)

        # Expect the code iterates in contiguous order, and is not tiled
        kernel_code = "\n".join(code[0].split("\n")[60:74])
        self.assertExpectedInline(
            kernel_code,
            """\
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 20
    x1 = (xindex // 20) % 20
    x2 = (xindex // 400)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x3), tmp2, xmask)""",  # noqa: B950
        )

    @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
    def test_int64_index_intermediate(self):
        def foo(inp):
            view_23 = torch.ops.aten.view.default(inp, [-1, 8192, 8192])
            split_1 = torch.ops.aten.split.Tensor(view_23, 1024, 1)
            view_23 = None
            getitem_17 = split_1[0]
            getitem_18 = split_1[1]
            getitem_19 = split_1[2]
            getitem_20 = split_1[3]
            getitem_21 = split_1[4]
            getitem_22 = split_1[5]
            getitem_23 = split_1[6]
            getitem_24 = split_1[7]
            split_1 = None
            cat_1 = torch.ops.aten.cat.default(
                [
                    getitem_17,
                    getitem_18,
                    getitem_19,
                    getitem_20,
                    getitem_21,
                    getitem_22,
                    getitem_23,
                    getitem_24,
                ]
            )
            getitem_17 = (
                getitem_18
            ) = (
                getitem_19
            ) = getitem_20 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None
            return cat_1

        for mark_dynamic in [False, True]:
            inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda")
            if mark_dynamic:
                torch._dynamo.mark_dynamic(inp, 0)
            foo_c = torch.compile(foo)
            torch.testing.assert_allclose(foo(inp), foo_c(inp))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CUDA

    if HAS_CUDA and not TEST_WITH_ASAN:
        run_tests(needs="filelock")