# Owner(s): ["module: inductor"]
#
# Dynamic-shape coverage for inductor. This module
#   * re-instantiates the tests of ``CommonTemplate`` with
#     ``torch._dynamo.config.assume_static_by_default`` patched to ``False``
#     (see ``make_dynamic_cls``), and
#   * defines ``TestInductorDynamic``, a set of device-generic tests that
#     compile functions with dynamic / unbacked sizes.
import contextlib
import importlib
import math
import operator
import os
import sys
import unittest
from functools import partial
from typing import List, Tuple

import torch
import torch.library
from torch._dynamo.testing import make_test_cls_with_patches
from torch._inductor import metrics
from torch._inductor.codegen.common import device_codegens, register_backend_for_device
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.wrapper import WrapperCodeGen
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyOn,
)
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_FBCODE,
    parametrize,
    TEST_CUDA_MEM_LEAK_CHECK,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from inductor.test_torchinductor import (
    check_model,
    check_model_gpu,
    CommonTemplate,
    copy_tests,
    TestFailure,
)


# Imported only for its side effect: fail early if filelock is not installed.
importlib.import_module("filelock")

# xfail by default, set is_skip=True to skip
test_failures = {
    "test_kwargs_dynamic_shapes": TestFailure(("cpu",)),
    # calling div on only symint args
    "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu")
    ),
    "test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")),
}

if TEST_WITH_ROCM:
    # Tensor-likes are not close
    test_failures["test_dynamic_stride_nobreak"] = TestFailure(
        ("cpu", "cuda"), is_skip=True
    )
    test_failures["test_item_to_inputs_kernel_nobreak"] = TestFailure(
        ("cpu", "cuda"), is_skip=True
    )
    # The suffixes argument must be a tuple; ("cpu") without the trailing comma
    # is just the string "cpu".
    test_failures["test_unbacked_reduction"] = TestFailure(("cpu",), is_skip=True)


if os.getenv("BUILD_ENVIRONMENT", "").endswith("-debug"):
    # Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073
    test_failures["test_resize_as_dynamic_shapes"] = TestFailure(("cpu", "cuda"))
    test_failures["test_resize_dynamic_shapes"] = TestFailure(("cpu", "cuda"))


def make_dynamic_cls(cls, xfail_prop="_expected_failure_dynamic"):
    """Return a copy of ``cls`` whose tests run with dynamic shapes enabled.

    The copy is produced by ``make_test_cls_with_patches`` using the class
    prefix ``"DynamicShapes"``, the test-name suffix ``"_dynamic_shapes"`` and
    a patch that sets ``torch._dynamo.config.assume_static_by_default`` to
    ``False``.

    Args:
        cls: the test class to derive from.
        xfail_prop: forwarded unchanged as the ``xfail_prop`` keyword of
            ``make_test_cls_with_patches``.
    """
    return make_test_cls_with_patches(
        cls,
        "DynamicShapes",
        "_dynamic_shapes",
        (torch._dynamo.config, "assume_static_by_default", False),
        xfail_prop=xfail_prop,
    )


DynamicShapesCommonTemplate = make_dynamic_cls(CommonTemplate)


if HAS_CPU:

    class DynamicShapesCpuTests(TestCase):
        """CPU instantiation of the dynamic-shapes template tests."""

        common = check_model
        device = "cpu"

    copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_failures)


if HAS_GPU and not TEST_WITH_ASAN:

    class DynamicShapesGPUTests(TestCase):
        """GPU instantiation of the dynamic-shapes template tests."""

        common = check_model_gpu
        device = GPU_TYPE

    copy_tests(
        DynamicShapesCommonTemplate, DynamicShapesGPUTests, GPU_TYPE, test_failures
    )


class TestInductorDynamic(TestCase):
    """Device-generic tests that compile functions with dynamic shapes.

    The class is instantiated per device at the bottom of the module through
    ``instantiate_device_type_tests``; test methods that declare a ``device``
    parameter receive the device string from that machinery.
    """

    # torch.compile pre-bound with dynamic=True; functools.partial objects are
    # not descriptors, so ``self.compile_fn(fn)`` does not pass ``self``.
    compile_fn = partial(torch.compile, dynamic=True)

    def setUp(self):
        # HAS_CUDA also checks compute capability to skip tests
        # on older devices
        if not HAS_GPU:
            self.skipTest("Triton not available")
        torch._dynamo.reset()
        TestCase.setUp(self)
        # this should be in setUpClass, but device-generic tests
        # don't work with setUpClass well (non-deterministically the wrong setUpClass is resolved),
        # so put it in test setUp, it's cheap
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(
            torch._inductor.config.patch(
                {
                    "debug": False,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                }
            )
        )

    def tearDown(self):
        # Undo the inductor config patch entered in setUp.
        self._stack.close()
        TestCase.tearDown(self)
        torch._dynamo.reset()

    def test_constant_fold_uniform_value_dynamic(self, device):
        def full_add_zero(x):
            a = torch.full(x.shape, 1, dtype=x.dtype, device=x.device)
            b = a - 1
            return x + b

        def full_mul_one(x):
            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
            b = 2 + a
            return x * b

        def full_view_op(x):
            a = torch.ones([1], dtype=x.dtype, device=x.device)
            a = a[:, None]
            return x * a

        def full_mul_symint(x):
            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
            b = 2 + a
            return b * x.shape[0]

        # NOTE(review): full_mul_symint is defined above but is not part of
        # ``fns``, so it is never compiled and the ``is not full_mul_symint``
        # check below is always true. Confirm whether it should be added.
        fns = (full_add_zero, full_mul_one, full_view_op)

        x = torch.randn((2, 4), device=device)
        y = torch.randn((3, 4), device=device)

        for dynamic in [False, True]:
            torch._dynamo.reset()
            for fn in fns:
                ref = fn(x)
                fn_c = torch.compile(fn, dynamic=dynamic)

                actual, source_codes = run_and_get_code(fn_c, x)

                if fn is not full_mul_symint:
                    # due to constant folding, fn returns x directly.
                    if device == "cpu":
                        FileCheck().check_not("cpp_fused").run(source_codes[0])
                    else:
                        FileCheck().check_not("triton.jit").run(source_codes[0])

                self.assertEqual(ref, actual)
                self.assertEqual(fn(x), fn_c(x))
                self.assertEqual(fn(y), fn_c(y))

    def test_arange_dynamic(self, device):
        def fn(a):
            batch_size = a.numel()
            max_len = a.max()
            return ~(
                torch.arange(0, max_len, device=a.device)
                .type_as(a)
                .repeat(batch_size, 1)
                .lt(a.unsqueeze(1))
            )

        a = torch.randint(10, 30, (10,), device=device)
        a[0] = 29  # fix max_len
        opt = self.compile_fn(fn)
        res = opt(a)
        ref = fn(a)
        self.assertEqual(res, ref)

    def test_shape_as_constant_reciprocal_float_exp(self, device):
        def fn(x, a):
            return x, -1 / a**1.0

        x = torch.rand(10, 20, device=device)
        opt = self.compile_fn(fn)
        res = opt(x, x.size(0))
        ref = fn(x, x.size(0))
        self.assertEqual(res, ref)

    # not supported yet on cpu, https://github.com/pytorch/pytorch/issues/109897
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_bool_mask_nobreak(self, device):
        def f(x, b):
            return (x[b] * 2).sum()

        opt_f = torch.compile(f, fullgraph=True)
        x = torch.randn(5, device=device)
        b = torch.tensor([True, True, False, False, True], device=device)
        r = f(x, b)
        opt_r = opt_f(x, b)
        self.assertEqual(r, opt_r)

    def test_adaptive_max_pool3d_with_indices(self, device):
        x = 5
        y = torch.rand([9, 10, 9, 8, 6], dtype=torch.float32, device=device)

        def fn(x, y):
            return torch.nn.functional.adaptive_max_pool3d_with_indices(
                output_size=x, input=y, return_indices=True
            )

        opt_f = self.compile_fn(fn)
        r = fn(x, y)
        opt_r = opt_f(x, y)
        self.assertEqual(r, opt_r)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unwrap_storage_didnt_work_repro(self, device):
        def f():
            full = torch.full((), 11)
            i0 = full.item()
            torch._check_is_size(i0)
            return torch.full((i0,), 0)

        opt_f = torch.compile(f, fullgraph=True)
        r = f()
        opt_r = opt_f()
        self.assertEqual(r, opt_r)

    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_nonzero_size_factory_nobreak(self, device):
        def f(x, b):
            y = torch.nonzero(b)
            return x.new_zeros(y.size(0))

        opt_f = torch.compile(f, fullgraph=True)
        x = torch.randn(5, device=device)
        b = torch.tensor([True, True, False, False, True], device=device)
        r = f(x, b)
        opt_r = opt_f(x, b)
        self.assertEqual(r, opt_r)

    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_nonzero_no_realloc(self, device):
        @torch.compile(fullgraph=True, dynamic=True)
        def f(x, y):
            z = x.nonzero()
            return torch.split(z, [y.size(0)])

        f(torch.tensor([1, 0, 1, 1, 0, 1, 0]), torch.randn(4))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_nobreak(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            return torch.empty(y)

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_bool_nobreak(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item()

        f(torch.tensor([True], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_noops_tensor_repropagate(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            b = torch.ops.prims.convert_element_type.default(x, torch.int64)
            r = b.nonzero()
            return r * 2

        f(torch.tensor([0, 4, 2, 0, 1], dtype=torch.int64, device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_zeros_nobreak(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            torch.empty(y)
            # This will avoid a NopSchedulerNode
            return x.new_zeros(y)

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_return(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            z = x.item()
            return y + z

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_float_item_inf(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item() == math.inf

        f(torch.tensor([3.0], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_float_item_neginf(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item() == -math.inf

        f(torch.tensor([3.0], device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_item_to_inputs_kernel_nobreak(self, device):
        @torch.library.custom_op("test::foo", mutates_args=())
        def foo(x: torch.Tensor, y: int) -> torch.Tensor:
            return x.clone()

        @foo.register_fake
        def _(x: torch.Tensor, y: int) -> torch.Tensor:
            return x.clone()

        @torch.compile(fullgraph=True)
        def f(x, r):
            y = x.item()
            return torch.ops.test.foo(r, y)

        f(torch.tensor([3], device=device), torch.randn(10, device=device))

    @unittest.skipUnless(IS_FBCODE, "")
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_float_item_return(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item()

        f(torch.tensor([3.0], device=device))

    @unittest.skipIf(TEST_CUDA_MEM_LEAK_CHECK, "failing memory leak check")
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_index_select(self, device):
        # Tests if unbacked symbols captured by inner_fn are properly tracked
        def f(x):
            y = x.item()
            return torch.index_select(
                torch.ones(y, device=device), 0, torch.tensor([0, 2, 1], device=device)
            )

        cf = torch.compile(fullgraph=True)(f)
        arg = torch.tensor(5, device=device)
        self.assertEqual(f(arg), cf(arg))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_return_unbacked_view_split(self, device):
        def f(values, length_per_key):
            u0, u1 = length_per_key.tolist()
            torch._check_is_size(u0)
            torch._check_is_size(u1)
            v1, v2 = torch.functional.split(values, [u0, u1])
            return v1, v2

        cf = torch.compile(fullgraph=True)(f)
        args = (
            torch.randn(8, requires_grad=True, device=device),
            torch.tensor([3, 5], device=device),
        )
        self.assertEqual(f(*args), cf(*args))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_matmul(self, device):
        def f(x):
            y = x.item()
            return torch.ones(1, y, device=device) @ torch.ones(y, 1, device=device)

        cf = torch.compile(fullgraph=True)(f)
        arg = torch.tensor(5, device=device)
        self.assertEqual(f(arg), cf(arg))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_unbacked_save_for_backwards(self, device) -> None:
        @torch.library.custom_op("_test::_cat", mutates_args=())
        def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
            return t * t.new_ones([sum(ds)])

        @torch.library.register_fake("_test::_cat")
        def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
            # Plain loop: the calls are made only for their side effect.
            for d in ds:
                torch._check_is_size(d)
            return t.new_empty([sum(ds)])

        def _cat_setup_context(ctx, inputs, output):
            pass

        def _cat_backward(ctx, grad):
            return grad.sum(), None

        torch.library.register_autograd(
            "_test::_cat",
            _cat_backward,
            setup_context=_cat_setup_context,
        )

        def fn(t, sizes):
            r = torch.ops._test._cat(t, sizes.tolist())
            return r * t

        t = torch.randn((), requires_grad=True, device=device)
        sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu")
        # Eager reference gradient.
        out = fn(t, sizes)
        out.sum().backward()
        expect = t.grad
        t.grad = None
        # Compiled gradient must match the eager one.
        torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)(
            t, sizes
        ).sum().backward()
        self.assertEqual(t.grad, expect)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_reduction(self, device):
        # On non-ARM64 CPU this test is expected to raise; a pass there is
        # reported as a failure so the expectation gets revisited.
        expect_fail = device == "cpu" and not IS_ARM64
        try:

            def f(x):
                y = x.item()
                return torch.ones(y, device=device).sum()

            cf = torch.compile(fullgraph=True)(f)
            arg = torch.tensor(5, device=device)
            self.assertEqual(f(arg), cf(arg))
        except Exception:
            if not expect_fail:
                raise
        else:
            if expect_fail:
                self.fail("expected to fail, but actually passed")

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_cat_unbacked_duplicate_size(self, device):
        def f(x):
            device = x.device
            s, s2 = x.tolist()
            g = torch.zeros(s, device=device)
            g2 = torch.ones(s2, device=device)
            return torch.ops.aten.cat.default([g, g, g2])

        cf = torch.compile(fullgraph=True)(f)
        # NOTE(review): the input is created on GPU_TYPE rather than on the
        # ``device`` parameter -- confirm whether that is intentional.
        arg = torch.tensor([4, 6], device=GPU_TYPE)
        self.assertEqual(f(arg), cf(arg))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_cat_backwards(self, device):
        def f(x, w):
            device = w.device
            a, b = x.tolist()
            ta = torch.ones(a, device=device)
            tb = torch.ones(b, device=device)
            pa = ta * w  # make it require gradients
            pb = tb * w
            r = torch.cat([pa, pb])
            return r.sum()

        x = torch.tensor([4, 9])
        w = torch.randn(1, requires_grad=True)
        f(x, w).backward()
        orig_w = w.grad
        w.grad = None

        torch.compile(fullgraph=True)(f)(x, w).backward()
        self.assertEqual(orig_w, w.grad)

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_cat_backwards_save_data_dependent(self, device):
        def f(x, w):
            device = w.device
            a, b = x.tolist()
            ta = torch.ones(a, device=device)
            tb = torch.ones(b, device=device)
            pa = ta * w  # make it require gradients
            pb = tb * w
            r = torch.cat([pa, pb])
            return r

        x = torch.tensor([4, 9])
        w = torch.randn(1, requires_grad=True)
        f(x, w).sum().backward()
        orig_w = w.grad
        w.grad = None

        torch.compile(fullgraph=True)(f)(x, w).sum().backward()
        self.assertEqual(orig_w, w.grad)

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_dynamic_stride_nobreak(self, device):
        @torch.library.custom_op("test::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            stride = x.item()
            return torch.empty_strided((1,), (stride,), device=x.device)

        @foo.register_fake
        def _(x: torch.Tensor) -> torch.Tensor:
            ctx = torch.library.get_ctx()
            stride = ctx.new_dynamic_size()
            return torch.empty_strided((1,), (stride,), device=x.device)

        @torch.compile(fullgraph=True)
        def f(x):
            r = torch.ops.test.foo(x)
            y = r.stride(0)
            return torch.empty(y, device=x.device)

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_multi_output_unbacked_custom_op(self, device):
        @torch.library.custom_op("test::foo", mutates_args=())
        def foo(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return torch.empty(2, device=x.device), torch.empty(3, device=x.device)

        # The fake implementation returns two tensors, like the real op, so it
        # is annotated with the same tuple return type.
        @foo.register_fake
        def _(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            ctx = torch.library.get_ctx()
            u0 = ctx.new_dynamic_size()
            return torch.empty(u0, device=x.device), torch.empty(3, device=x.device)

        @torch.compile(fullgraph=True)
        def f(x):
            a, b = torch.ops.test.foo(x)
            return a.sum() + b.sum()

        f(torch.tensor([3], device=device))

    @torch._inductor.config.patch(disable_cpp_codegen=True)
    def test_floor(self):
        # `int(n * 0.2)` will be generated as `floor(0.2*s0)` of torch.SymInt type.
        # If cpp codegen is disabled, we should generate `math.floor` using PythonPrinter.
        def fn(x):
            n = x.size(-1)
            y = x + int(n * 0.2) + 1
            return y

        opt = self.compile_fn(fn)
        # The first run doesn't trigger dynamic shapes.
        x0 = torch.rand(5)
        ref0 = fn(x0)
        res0 = opt(x0)
        self.assertEqual(ref0, res0)
        # The second run triggers dynamic shapes.
        x1 = torch.rand(8)
        ref1 = fn(x1)
        res1 = opt(x1)
        self.assertEqual(ref1, res1)

    @onlyOn(GPU_TYPE)
    def test_pad_dynamic(self, device):
        def get_same_padding(x: int, k: int, s: int, d: int):
            return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

        def pad_same(x, k, s, d=(1, 1), value=0):
            ih, iw = x.size()[-2:]
            pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(
                iw, k[1], s[1], d[1]
            )
            if pad_h > 0 or pad_w > 0:
                x = torch.nn.functional.pad(
                    x,
                    [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
                    value=value,
                )
            return x

        x = torch.randn(2, 24, 110, 110, device=device)
        opt = self.compile_fn(pad_same)
        res = opt(x, (5, 5), (2, 2))
        ref = pad_same(x, (5, 5), (2, 2))
        self.assertEqual(res, ref, atol=0, rtol=0)

    def test_slice_scatter(self, device):
        def fn(i):
            s3 = i.size(0)
            x = torch.ones(64, s3, device=device)
            y = torch.ones(64, s3 // 2, device=device)
            return torch.slice_scatter(x, y, 1, s3 // 2, 2 * (s3 // 2))

        a = torch.randn(16, device=device)
        cfn = self.compile_fn(fn)
        expect = fn(a)
        actual = cfn(a)
        self.assertEqual(expect, actual)

    def test_slice_index_changing_sign(self, device):
        def fn(x, y):
            y0, y1 = y.shape
            return x[: (y0 - y1)].clone()

        a = torch.randn(32, 32, device=device)
        cfn = self.compile_fn(fn)

        # y0 > y1 -> y0 - y1 is positive
        b = torch.randn(16, 2, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

        # y0 < y1 -> y0 - y1 is negative
        b = torch.randn(2, 16, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

    def test_sym_stride_lowering(self, device):
        def fn(x):
            s0 = (x + 1).stride(0)
            return x * s0

        a = torch.randn(32, 32, device=device)
        cfn = self.compile_fn(fn)
        self.assertEqual(fn(a), cfn(a))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_materialize(self, device):
        def fn(x):
            return x.sum(dim=0).view(4).tolist()

        cfn = torch.compile(fullgraph=True)(fn)

        a = torch.ones(3, 4, dtype=torch.int64, device=device)
        self.assertEqual(cfn(a), fn(a))

    def test_abs(self, device):
        def fn(x, y):
            y0, y1 = y.shape
            # Slicing checks abs in wrapper code,
            # multiplication tests abs in kernel code
            return x[: abs(y0 - y1)] * abs(y0 - y1)

        a = torch.randn(32, 32, device=device)
        cfn = self.compile_fn(fn)

        # y0 > y1 -> y0 - y1 is positive
        b = torch.randn(16, 2, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

        # y0 < y1 -> y0 - y1 is negative
        b = torch.randn(2, 16, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

    def test_float_is_integer(self, device):
        def fn(x, mul, dim=-1):
            size = x.size(dim)
            m = size / mul
            if m.is_integer():
                return m
            return size

        a = torch.randn((3, 6, 4, 2), device=device)
        cfn = self.compile_fn(fn)

        expect = fn(a, 2)
        actual = cfn(a, 2)
        self.assertEqual(expect, actual)

    @onlyCPU
    def test_arithmetic_constant_folding(self, device):
        def test(fn):
            cfn = self.compile_fn(fn)
            expect = fn(3)
            actual = cfn(3)
            self.assertEqual(expect, actual)

        def add(x):
            return x + torch.zeros(3)

        test(add)

        def mul(x):
            return x * torch.ones(3)

        test(mul)

        def div(x):
            return x / torch.ones(3)

        test(div)

    @onlyCPU
    def test_sub_constant_folding(self, device):
        def sub(x):
            return x - torch.zeros(3)

        cfn = self.compile_fn(sub)
        expect = sub(3)
        actual = cfn(3)
        self.assertEqual(expect, actual)

    def test_full_symbolic_value(self, device):
        def fn(a):
            return torch.full((3,), a), torch.full((3,), torch.sym_float(a))

        cfn = self.compile_fn(fn)
        expect = fn(5)
        actual = cfn(5)
        self.assertEqual(expect, actual)

    def test_interpolate_ceil_eq(self, device):
        ceiling = math.ceil
        IntTrueDiv = operator.truediv

        def fn(t):
            s0, s2, s3 = t.size()
            x = torch.zeros(
                (
                    s0,
                    2048,
                    ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)),
                    ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)),
                ),
                dtype=torch.bfloat16,
            )
            return torch.nn.functional.interpolate(
                x,
                scale_factor=2,
                mode="nearest",
            )

        cfn = self.compile_fn(fn)
        arg = torch.randn(4, 16, 18)
        expect = fn(arg)
        actual = cfn(arg)
        self.assertEqual(expect, actual)

    def test_full_recompiles(self, device):
        def fn(x):
            _, L = x.shape
            return torch.full((L, L), torch.finfo(torch.float16).min, device=device)

        cfn = self.compile_fn(fn)

        # ``partial`` is already imported at module scope.
        input_fn = partial(torch.randint, 10, 1000, device=device)

        cfn(input_fn((2, 3)))
        cfn(input_fn((2, 4)))  # expect don't recompile here

        # check compiled times of frame 0
        from torch._dynamo.convert_frame import FRAME_COMPILE_COUNTER

        self.assertEqual(FRAME_COMPILE_COUNTER[0], 1)

    @parametrize(
        "op",
        [
            math.sqrt,
            math.sin,
            math.cos,
            math.cosh,
            math.sin,
            math.sinh,
            math.tan,
            math.tanh,
            math.asin,
            math.acos,
            math.atan,
        ],
    )
    def test_math_ops(self, device, op):
        def func(x, fn, a):
            return x + fn(a)

        cfunc = self.compile_fn(func, fullgraph=True)
        x = torch.rand(10, device=device)
        # asin/acos are only defined on [-1, 1]
        a = -1 if op in (math.asin, math.acos) else 12
        expected = func(x, op, a)
        output = cfunc(x, op, a)
        self.assertEqual(output, expected)

    def test_wrapper_codegen_statically_known_int_or_none(self):
        torch._dynamo.reset()

        _x = torch.randn([5, 3, 3])
        torch._dynamo.maybe_mark_dynamic(_x, 0)

        # Simple functions introducing constraints on x.shape[0]
        def fn_1(x):
            # no constraint
            return x.sin()

        def fn_2(x):
            # constrain in two directions
            if x.shape[0] > 5:
                return x.cos()
            if x.shape[0] < 5:
                return x * 2
            # x.shape[0] == 5 at this point
            return x.sin()

        def fn_3(x):
            # equality constraint, which matches example shape
            if x.size(0) == 5:
                return x.sin()
            else:
                return x.cos()

        # Number of times the wrapper's generate() has run; used to tell which
        # of fn_1 / fn_2 / fn_3 is currently being compiled.
        call_count = 0

        def _test_wrapper_codegen_statically_known_int_or_none_in_context():
            nonlocal call_count
            call_count += 1
            graph = V.graph
            input_layouts = [
                inp.layout
                for inp in graph.graph_inputs.values()
                if hasattr(inp, "layout")
            ]
            batch_dim = input_layouts[0].size[0]
            if call_count == 1:
                # testing fn_1
                assert (
                    WrapperCodeGen.statically_known_int_or_none(batch_dim) is None
                ), "Should not be statically known on first call"
            elif call_count == 2:
                # testing fn_2
                assert (
                    WrapperCodeGen.statically_known_int_or_none(batch_dim) == 5
                ), "Should be limited to exactly 5 on second call due to multiple constraints"
            elif call_count == 3:
                # testing fn_3
                assert (
                    WrapperCodeGen.statically_known_int_or_none(batch_dim) == 5
                ), "Should be exactly 5 on third call"

        class TestWrapperCodegen(WrapperCodeGen):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def generate(self, is_inference, *args, **kwargs):
                # Run the checks while V.graph is the graph being compiled.
                _test_wrapper_codegen_statically_known_int_or_none_in_context()
                return super().generate(is_inference, *args, **kwargs)

        if "cpu" not in device_codegens:
            register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
        orig_cpu_codegens = device_codegens["cpu"]
        try:
            register_backend_for_device(
                "cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen
            )
            # Compile each of the functions above, with an example input
            # that has 5 in the first dimension, but is marked as dynamic

            torch.compile(backend="inductor", dynamic=None)(fn_1)(_x)
            torch.compile(backend="inductor", dynamic=None)(fn_2)(_x)
            torch.compile(backend="inductor", dynamic=None)(fn_3)(_x)
        finally:
            # Always restore the original CPU wrapper codegen.
            register_backend_for_device(
                "cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen
            )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_unbacked_stride_nobreak(self, device):
        @torch.compile(fullgraph=True, dynamic=True)
        def f(x):
            a = x.item()
            torch._check_is_size(a)
            torch._check(a >= 1)
            torch._check(a <= 10)
            return torch.ones(a, a)

        f(torch.tensor([5], device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symint_sum_list(self, device):
        @torch.compile()
        def f(xt):
            xs = xt.tolist()
            for x in xs:
                torch._check_is_size(x)
            y = sum(xs)
            return torch.zeros(y, device=device)

        f(torch.tensor([5] * 320))

    def test_sort_dynamic_shape_with_check(self, device):
        if TEST_WITH_ROCM or torch.device(device).type != GPU_TYPE:

            def check_count(n):
                self.assertEqual(metrics.generated_kernel_count, 0)

        else:

            def check_count(n):
                self.assertEqual(metrics.generated_kernel_count, n)

        # Test dynamic shapes with statically known small enough to generate
        # persistent sort kernel
        def fn(a, descending):
            torch._check(a.shape[-1] <= 256)
            return a.sort(dim=-1, stable=True, descending=descending)

        inp = torch.rand(10, 128, dtype=torch.float32, device=device)
        inp[:, 10:20] = 1.0
        inp[:, 30:40] = 1.0
        metrics.reset()

        opt_fn = torch.compile(fn, dynamic=True)
        expect = fn(inp, False)
        actual = opt_fn(inp, False)
        self.assertEqual(actual, expect)
        check_count(1)

        expect = fn(inp, True)
        actual = opt_fn(inp, True)
        self.assertEqual(actual, expect)
        check_count(2)

        # Non-power of two
        # NOTE(review): the result of this slice is discarded, so ``inp`` is
        # still 10x128 below and the non-power-of-two case is not exercised.
        # Presumably ``inp = inp[:, :120]`` was intended; left unchanged
        # because rebinding may affect the kernel counts asserted below --
        # confirm before changing.
        inp[:, :120]

        expect = fn(inp, False)
        actual = opt_fn(inp, False)
        self.assertEqual(actual, expect)
        check_count(2)  # Reused existing kernel

        expect = fn(inp, True)
        actual = opt_fn(inp, True)
        self.assertEqual(actual, expect)
        check_count(2)  # Reused existing kernel


instantiate_device_type_tests(TestInductorDynamic, globals(), allow_xpu=True)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
    if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
        run_tests(needs="filelock")