# xref: /aosp_15_r20/external/pytorch/test/inductor/test_torchinductor_dynamic_shapes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: inductor"]
import contextlib
import importlib
import math
import operator
import os
import sys
import unittest
from functools import partial
from typing import List, Tuple

import torch
import torch.library
from torch._dynamo.testing import make_test_cls_with_patches
from torch._inductor import metrics
from torch._inductor.codegen.common import device_codegens, register_backend_for_device
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.wrapper import WrapperCodeGen
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyOn,
)
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_FBCODE,
    parametrize,
    TEST_CUDA_MEM_LEAK_CHECK,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from inductor.test_torchinductor import (
    check_model,
    check_model_gpu,
    CommonTemplate,
    copy_tests,
    TestFailure,
)

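# The inductor test harness depends on the third-party "filelock" package (see
# run_tests(needs="filelock") at the bottom of this file); importing it up front
# surfaces a missing dependency as an immediate ImportError.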
importlib.import_module("filelock")

# xfail by default, set is_skip=True to skip
test_failures = {
    "test_kwargs_dynamic_shapes": TestFailure(("cpu",)),
    # calling div on only symint args
    "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu")
    ),
    "test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")),
}

if TEST_WITH_ROCM:
    # Tensor-likes are not close
    test_failures["test_dynamic_stride_nobreak"] = TestFailure(
        ("cpu", "cuda"), is_skip=True
    )
    test_failures["test_item_to_inputs_kernel_nobreak"] = TestFailure(
        ("cpu", "cuda"), is_skip=True
    )
    test_failures["test_unbacked_reduction"] = TestFailure(("cpu",), is_skip=True)

if os.getenv("BUILD_ENVIRONMENT", "").endswith("-debug"):
    # Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073
    test_failures["test_resize_as_dynamic_shapes"] = TestFailure(("cpu", "cuda"))
    test_failures["test_resize_dynamic_shapes"] = TestFailure(("cpu", "cuda"))


def make_dynamic_cls(cls, xfail_prop="_expected_failure_dynamic"):
    return make_test_cls_with_patches(
        cls,
        "DynamicShapes",
        "_dynamic_shapes",
        (torch._dynamo.config, "assume_static_by_default", False),
        xfail_prop=xfail_prop,
    )

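# make_dynamic_cls clones every test from CommonTemplate onto a new
# DynamicShapes* class, appending "_dynamic_shapes" to each test name and
# patching torch._dynamo.config.assume_static_by_default to False for the
# duration of each test (hence the "_dynamic_shapes" suffix in the
# test_failures keys above). copy_tests below then stamps those tests onto the
# per-device TestCase subclasses, applying the xfail/skip entries from
# test_failures.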
DynamicShapesCommonTemplate = make_dynamic_cls(CommonTemplate)


if HAS_CPU:

    class DynamicShapesCpuTests(TestCase):
        common = check_model
        device = "cpu"

    copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_failures)


if HAS_GPU and not TEST_WITH_ASAN:

    class DynamicShapesGPUTests(TestCase):
        common = check_model_gpu
        device = GPU_TYPE

    copy_tests(
        DynamicShapesCommonTemplate, DynamicShapesGPUTests, GPU_TYPE, test_failures
    )


class TestInductorDynamic(TestCase):
    compile_fn = partial(torch.compile, dynamic=True)

    def setUp(self):
        # HAS_CUDA also checks compute capability to skip tests
        # on older devices
        if not HAS_GPU:
            self.skipTest("Triton not available")
        torch._dynamo.reset()
        TestCase.setUp(self)
        # this should be in setUpClass, but device-generic tests
        # don't work with setUpClass well (non-deterministically the wrong setUpClass is resolved),
        # so put it in test setUp, it's cheap
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(
            torch._inductor.config.patch(
                {
                    "debug": False,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                }
            )
        )

    def tearDown(self):
        self._stack.close()
        TestCase.tearDown(self)
        torch._dynamo.reset()

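    # The tests below take a `device` argument; instantiate_device_type_tests at
    # the bottom of this file generates per-device copies of this class (CPU,
    # CUDA, and XPU via allow_xpu=True) and supplies that argument.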
    def test_constant_fold_uniform_value_dynamic(self, device):
        def full_add_zero(x):
            a = torch.full(x.shape, 1, dtype=x.dtype, device=x.device)
            b = a - 1
            return x + b

        def full_mul_one(x):
            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
            b = 2 + a
            return x * b

        def full_view_op(x):
            a = torch.ones([1], dtype=x.dtype, device=x.device)
            a = a[:, None]
            return x * a

        def full_mul_symint(x):
            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
            b = 2 + a
            return b * x.shape[0]

        fns = (full_add_zero, full_mul_one, full_view_op)

        x = torch.randn((2, 4), device=device)
        y = torch.randn((3, 4), device=device)

        for dynamic in [False, True]:
            torch._dynamo.reset()
            for fn in fns:
                ref = fn(x)
                fn_c = torch.compile(fn, dynamic=dynamic)

                actual, source_codes = run_and_get_code(fn_c, x)

                if fn is not full_mul_symint:
                    # due to constant folding, fn returns x directly.
                    if device == "cpu":
                        FileCheck().check_not("cpp_fused").run(source_codes[0])
                    else:
                        FileCheck().check_not("triton.jit").run(source_codes[0])

                self.assertEqual(ref, actual)
                self.assertEqual(fn(x), fn_c(x))
                self.assertEqual(fn(y), fn_c(y))

    def test_arange_dynamic(self, device):
        def fn(a):
            batch_size = a.numel()
            max_len = a.max()
            return ~(
                torch.arange(0, max_len, device=a.device)
                .type_as(a)
                .repeat(batch_size, 1)
                .lt(a.unsqueeze(1))
            )

        a = torch.randint(10, 30, (10,), device=device)
        a[0] = 29  # fix max_len
        opt = self.compile_fn(fn)
        res = opt(a)
        ref = fn(a)
        self.assertEqual(res, ref)

    def test_shape_as_constant_reciprocal_float_exp(self, device):
        def fn(x, a):
            return x, -1 / a**1.0

        x = torch.rand(10, 20, device=device)
        opt = self.compile_fn(fn)
        res = opt(x, x.size(0))
        ref = fn(x, x.size(0))
        self.assertEqual(res, ref)

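    # Many of the tests below rely on two dynamo config flags:
    # - capture_scalar_outputs=True lets Tensor.item() (and .tolist()) be traced
    #   into the graph as unbacked SymInt/SymFloat values instead of graph-breaking.
    # - capture_dynamic_output_shape_ops=True does the same for ops whose output
    #   shape is data dependent (nonzero, boolean mask indexing, etc.).
    # torch._check_is_size / torch._check are then used to tell the compiler the
    # ranges those unbacked symbols can take.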
    # not supported yet on cpu, https://github.com/pytorch/pytorch/issues/109897
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_bool_mask_nobreak(self, device):
        def f(x, b):
            return (x[b] * 2).sum()

        opt_f = torch.compile(f, fullgraph=True)
        x = torch.randn(5, device=device)
        b = torch.tensor([True, True, False, False, True], device=device)
        r = f(x, b)
        opt_r = opt_f(x, b)
        self.assertEqual(r, opt_r)

    def test_adaptive_max_pool3d_with_indices(self, device):
        x = 5
        y = torch.rand([9, 10, 9, 8, 6], dtype=torch.float32, device=device)

        def fn(x, y):
            return torch.nn.functional.adaptive_max_pool3d_with_indices(
                output_size=x, input=y, return_indices=True
            )

        opt_f = self.compile_fn(fn)
        r = fn(x, y)
        opt_r = opt_f(x, y)
        self.assertEqual(r, opt_r)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unwrap_storage_didnt_work_repro(self, device):
        def f():
            full = torch.full((), 11)
            i0 = full.item()
            torch._check_is_size(i0)
            return torch.full((i0,), 0)

        opt_f = torch.compile(f, fullgraph=True)
        r = f()
        opt_r = opt_f()
        self.assertEqual(r, opt_r)

    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_nonzero_size_factory_nobreak(self, device):
        def f(x, b):
            y = torch.nonzero(b)
            return x.new_zeros(y.size(0))

        opt_f = torch.compile(f, fullgraph=True)
        x = torch.randn(5, device=device)
        b = torch.tensor([True, True, False, False, True], device=device)
        r = f(x, b)
        opt_r = opt_f(x, b)
        self.assertEqual(r, opt_r)

    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def test_nonzero_no_realloc(self, device):
        @torch.compile(fullgraph=True, dynamic=True)
        def f(x, y):
            z = x.nonzero()
            return torch.split(z, [y.size(0)])

        f(torch.tensor([1, 0, 1, 1, 0, 1, 0]), torch.randn(4))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_nobreak(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            return torch.empty(y)

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_bool_nobreak(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item()

        f(torch.tensor([True], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_noops_tensor_repropagate(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            b = torch.ops.prims.convert_element_type.default(x, torch.int64)
            r = b.nonzero()
            return r * 2

        f(torch.tensor([0, 4, 2, 0, 1], dtype=torch.int64, device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_zeros_nobreak(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            torch.empty(y)
            # This will avoid a NopSchedulerNode
            return x.new_zeros(y)

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_return(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            z = x.item()
            return y + z

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_float_item_inf(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item() == math.inf

        f(torch.tensor([3.0], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_float_item_neginf(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item() == -math.inf

        f(torch.tensor([3.0], device=device))

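    # The custom-op tests below follow the torch.library pattern:
    # @torch.library.custom_op registers an opaque operator, and register_fake
    # provides a shape-only ("fake") implementation that is used while tracing,
    # so output metadata can be computed without running the real kernel.
    # implicit_fallbacks=True allows Inductor to fall back to calling the op
    # directly rather than requiring a lowering for it.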
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_item_to_inputs_kernel_nobreak(self, device):
        @torch.library.custom_op("test::foo", mutates_args=())
        def foo(x: torch.Tensor, y: int) -> torch.Tensor:
            return x.clone()

        @foo.register_fake
        def _(x: torch.Tensor, y: int) -> torch.Tensor:
            return x.clone()

        @torch.compile(fullgraph=True)
        def f(x, r):
            y = x.item()
            return torch.ops.test.foo(r, y)

        f(torch.tensor([3], device=device), torch.randn(10, device=device))

    @unittest.skipUnless(IS_FBCODE, "")
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_float_item_return(self, device):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.item()

        f(torch.tensor([3.0], device=device))

    @unittest.skipIf(TEST_CUDA_MEM_LEAK_CHECK, "failing memory leak check")
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_index_select(self, device):
        # Tests if unbacked symbols captured by inner_fn are properly tracked
        def f(x):
            y = x.item()
            return torch.index_select(
                torch.ones(y, device=device), 0, torch.tensor([0, 2, 1], device=device)
            )

        cf = torch.compile(fullgraph=True)(f)
        arg = torch.tensor(5, device=device)
        self.assertEqual(f(arg), cf(arg))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_return_unbacked_view_split(self, device):
        def f(values, length_per_key):
            u0, u1 = length_per_key.tolist()
            torch._check_is_size(u0)
            torch._check_is_size(u1)
            v1, v2 = torch.functional.split(values, [u0, u1])
            return v1, v2

        cf = torch.compile(fullgraph=True)(f)
        args = (
            torch.randn(8, requires_grad=True, device=device),
            torch.tensor([3, 5], device=device),
        )
        self.assertEqual(f(*args), cf(*args))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_matmul(self, device):
        def f(x):
            y = x.item()
            return torch.ones(1, y, device=device) @ torch.ones(y, 1, device=device)

        cf = torch.compile(fullgraph=True)(f)
        arg = torch.tensor(5, device=device)
        self.assertEqual(f(arg), cf(arg))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_unbacked_save_for_backwards(self, device) -> None:
        @torch.library.custom_op("_test::_cat", mutates_args=())
        def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
            return t * t.new_ones([sum(ds)])

        @torch.library.register_fake("_test::_cat")
        def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
            [torch._check_is_size(d) for d in ds]
            return t.new_empty([sum(ds)])

        def _cat_setup_context(ctx, inputs, output):
            pass

        def _cat_backward(ctx, grad):
            return grad.sum(), None

        torch.library.register_autograd(
            "_test::_cat",
            _cat_backward,
            setup_context=_cat_setup_context,
        )

        def fn(t, sizes):
            r = torch.ops._test._cat(t, sizes.tolist())
            return r * t

        t = torch.randn((), requires_grad=True, device=device)
        sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu")
        out = fn(t, sizes)
        out.sum().backward()
        expect = t.grad
        t.grad = None
        torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)(
            t, sizes
        ).sum().backward()
        self.assertEqual(t.grad, expect)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_reduction(self, device):
        expect_fail = device == "cpu" and not IS_ARM64
        try:

            def f(x):
                y = x.item()
                return torch.ones(y, device=device).sum()

            cf = torch.compile(fullgraph=True)(f)
            arg = torch.tensor(5, device=device)
            self.assertEqual(f(arg), cf(arg))
        except Exception:
            if not expect_fail:
                raise
        else:
            if expect_fail:
                self.fail("expected to fail, but actually passed")

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_cat_unbacked_duplicate_size(self, device):
        def f(x):
            device = x.device
            s, s2 = x.tolist()
            g = torch.zeros(s, device=device)
            g2 = torch.ones(s2, device=device)
            return torch.ops.aten.cat.default([g, g, g2])

        cf = torch.compile(fullgraph=True)(f)
        arg = torch.tensor([4, 6], device=GPU_TYPE)
        self.assertEqual(f(arg), cf(arg))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_cat_backwards(self, device):
        def f(x, w):
            device = w.device
            a, b = x.tolist()
            ta = torch.ones(a, device=device)
            tb = torch.ones(b, device=device)
            pa = ta * w  # make it require gradients
            pb = tb * w
            r = torch.cat([pa, pb])
            return r.sum()

        x = torch.tensor([4, 9])
        w = torch.randn(1, requires_grad=True)
        f(x, w).backward()
        orig_w = w.grad
        w.grad = None

        torch.compile(fullgraph=True)(f)(x, w).backward()
        self.assertEqual(orig_w, w.grad)

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_cat_backwards_save_data_dependent(self, device):
        def f(x, w):
            device = w.device
            a, b = x.tolist()
            ta = torch.ones(a, device=device)
            tb = torch.ones(b, device=device)
            pa = ta * w  # make it require gradients
            pb = tb * w
            r = torch.cat([pa, pb])
            return r

        x = torch.tensor([4, 9])
        w = torch.randn(1, requires_grad=True)
        f(x, w).sum().backward()
        orig_w = w.grad
        w.grad = None

        torch.compile(fullgraph=True)(f)(x, w).sum().backward()
        self.assertEqual(orig_w, w.grad)

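    # In the fake implementations below, torch.library.get_ctx().new_dynamic_size()
    # allocates a fresh data-dependent (unbacked) SymInt, which is how a fake impl
    # models an output size or stride that is only known at runtime.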
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_dynamic_stride_nobreak(self, device):
        @torch.library.custom_op("test::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            stride = x.item()
            return torch.empty_strided((1,), (stride,), device=x.device)

        @foo.register_fake
        def _(x: torch.Tensor) -> torch.Tensor:
            ctx = torch.library.get_ctx()
            stride = ctx.new_dynamic_size()
            return torch.empty_strided((1,), (stride,), device=x.device)

        @torch.compile(fullgraph=True)
        def f(x):
            r = torch.ops.test.foo(x)
            y = r.stride(0)
            return torch.empty(y, device=x.device)

        f(torch.tensor([3], device=device))

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    @torch._inductor.config.patch(implicit_fallbacks=True)
    def test_multi_output_unbacked_custom_op(self, device):
        @torch.library.custom_op("test::foo", mutates_args=())
        def foo(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return torch.empty(2, device=x.device), torch.empty(3, device=x.device)

        @foo.register_fake
        def _(x: torch.Tensor) -> torch.Tensor:
            ctx = torch.library.get_ctx()
            u0 = ctx.new_dynamic_size()
            return torch.empty(u0, device=x.device), torch.empty(3, device=x.device)

        @torch.compile(fullgraph=True)
        def f(x):
            a, b = torch.ops.test.foo(x)
            return a.sum() + b.sum()

        f(torch.tensor([3], device=device))

    @torch._inductor.config.patch(disable_cpp_codegen=True)
    def test_floor(self):
        # `int(n * 0.2)` will be generated as `floor(0.2*s0)` of torch.SymInt type.
        # If cpp codegen is disabled, we should generate `math.floor` using PythonPrinter.
        def fn(x):
            n = x.size(-1)
            y = x + int(n * 0.2) + 1
            return y

        opt = self.compile_fn(fn)
        # The first run doesn't trigger dynamic shapes.
        x0 = torch.rand(5)
        ref0 = fn(x0)
        res0 = opt(x0)
        self.assertEqual(ref0, res0)
        # The second run triggers dynamic shapes.
        x1 = torch.rand(8)
        ref1 = fn(x1)
        res1 = opt(x1)
        self.assertEqual(ref1, res1)

    @onlyOn(GPU_TYPE)
    def test_pad_dynamic(self, device):
        def get_same_padding(x: int, k: int, s: int, d: int):
            return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

        def pad_same(x, k, s, d=(1, 1), value=0):
            ih, iw = x.size()[-2:]
            pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(
                iw, k[1], s[1], d[1]
            )
            if pad_h > 0 or pad_w > 0:
                x = torch.nn.functional.pad(
                    x,
                    [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
                    value=value,
                )
            return x

        x = torch.randn(2, 24, 110, 110, device=device)
        opt = self.compile_fn(pad_same)
        res = opt(x, (5, 5), (2, 2))
        ref = pad_same(x, (5, 5), (2, 2))
        self.assertEqual(res, ref, atol=0, rtol=0)

    def test_slice_scatter(self, device):
        def fn(i):
            s3 = i.size(0)
            x = torch.ones(64, s3, device=device)
            y = torch.ones(64, s3 // 2, device=device)
            return torch.slice_scatter(x, y, 1, s3 // 2, 2 * (s3 // 2))

        a = torch.randn(16, device=device)
        cfn = self.compile_fn(fn)
        expect = fn(a)
        actual = cfn(a)
        self.assertEqual(expect, actual)

    def test_slice_index_changing_sign(self, device):
        def fn(x, y):
            y0, y1 = y.shape
            return x[: (y0 - y1)].clone()

        a = torch.randn(32, 32, device=device)
        cfn = self.compile_fn(fn)

        # y0 > y1 -> y0 - y1 is positive
        b = torch.randn(16, 2, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

        # y0 < y1 -> y0 - y1 is negative
        b = torch.randn(2, 16, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

    def test_sym_stride_lowering(self, device):
        def fn(x):
            s0 = (x + 1).stride(0)
            return x * s0

        a = torch.randn(32, 32, device=device)
        cfn = self.compile_fn(fn)
        self.assertEqual(fn(a), cfn(a))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_materialize(self, device):
        def fn(x):
            return x.sum(dim=0).view(4).tolist()

        cfn = torch.compile(fullgraph=True)(fn)

        a = torch.ones(3, 4, dtype=torch.int64, device=device)
        self.assertEqual(cfn(a), fn(a))

    def test_abs(self, device):
        def fn(x, y):
            y0, y1 = y.shape
            # Slicing checks abs in wrapper code,
            # multiplication tests abs in kernel code
            return x[: abs(y0 - y1)] * abs(y0 - y1)

        a = torch.randn(32, 32, device=device)
        cfn = self.compile_fn(fn)

        # y0 > y1 -> y0 - y1 is positive
        b = torch.randn(16, 2, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

        # y0 < y1 -> y0 - y1 is negative
        b = torch.randn(2, 16, device=device)
        expect = fn(a, b)
        actual = cfn(a, b)
        self.assertEqual(expect, actual)

    def test_float_is_integer(self, device):
        def fn(x, mul, dim=-1):
            size = x.size(dim)
            m = size / mul
            if m.is_integer():
                return m
            return size

        a = torch.randn((3, 6, 4, 2), device=device)
        cfn = self.compile_fn(fn)

        expect = fn(a, 2)
        actual = cfn(a, 2)
        self.assertEqual(expect, actual)

    @onlyCPU
    def test_arithmetic_constant_folding(self, device):
        def test(fn):
            cfn = self.compile_fn(fn)
            expect = fn(3)
            actual = cfn(3)
            self.assertEqual(expect, actual)

        def add(x):
            return x + torch.zeros(3)

        test(add)

        def mul(x):
            return x * torch.ones(3)

        test(mul)

        def div(x):
            return x / torch.ones(3)

        test(div)

    @onlyCPU
    def test_sub_constant_folding(self, device):
        def sub(x):
            return x - torch.zeros(3)

        cfn = self.compile_fn(sub)
        expect = sub(3)
        actual = cfn(3)
        self.assertEqual(expect, actual)

    def test_full_symbolic_value(self, device):
        def fn(a):
            return torch.full((3,), a), torch.full((3,), torch.sym_float(a))

        cfn = self.compile_fn(fn)
        expect = fn(5)
        actual = cfn(5)
        self.assertEqual(expect, actual)

    def test_interpolate_ceil_eq(self, device):
        ceiling = math.ceil
        IntTrueDiv = operator.truediv

        def fn(t):
            s0, s2, s3 = t.size()
            x = torch.zeros(
                (
                    s0,
                    2048,
                    ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)),
                    ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)),
                ),
                dtype=torch.bfloat16,
            )
            return torch.nn.functional.interpolate(
                x,
                scale_factor=2,
                mode="nearest",
            )

        cfn = self.compile_fn(fn)
        arg = torch.randn(4, 16, 18)
        expect = fn(arg)
        actual = cfn(arg)
        self.assertEqual(expect, actual)

    def test_full_recompiles(self, device):
        def fn(x):
            _, L = x.shape
            return torch.full((L, L), torch.finfo(torch.float16).min, device=device)

        cfn = self.compile_fn(fn)

        import functools

        input_fn = functools.partial(torch.randint, 10, 1000, device=device)

        cfn(input_fn((2, 3)))
        cfn(input_fn((2, 4)))  # expect don't recompile here

        # check compiled times of frame 0
        from torch._dynamo.convert_frame import FRAME_COMPILE_COUNTER

        self.assertEqual(FRAME_COMPILE_COUNTER[0], 1)

    @parametrize(
        "op",
        [
            math.sqrt,
            math.sin,
            math.cos,
            math.cosh,
            math.sinh,
            math.tan,
            math.tanh,
            math.asin,
            math.acos,
            math.atan,
        ],
    )
    def test_math_ops(self, device, op):
        def func(x, fn, a):
            return x + fn(a)

        cfunc = self.compile_fn(func, fullgraph=True)
        x = torch.rand(10, device=device)
        a = -1 if op in (math.asin, math.acos) else 12
        expected = func(x, op, a)
        output = cfunc(x, op, a)
        self.assertEqual(output, expected)

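    # WrapperCodeGen.statically_known_int_or_none(expr) returns a concrete int only
    # when the accumulated guards pin the shape expression to a single value, and
    # None otherwise; fn_1/fn_2/fn_3 below construct both cases.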
    def test_wrapper_codegen_statically_known_int_or_none(self):
        torch._dynamo.reset()

        _x = torch.randn([5, 3, 3])
        torch._dynamo.maybe_mark_dynamic(_x, 0)

        # Simple functions introducing constraints on x.shape[0]
        def fn_1(x):
            # no constraint
            return x.sin()

        def fn_2(x):
            # constrain in two directions
            if x.shape[0] > 5:
                return x.cos()
            if x.shape[0] < 5:
                return x * 2
            # x.shape[0] == 5 at this point
            return x.sin()

        def fn_3(x):
            # equality constraint, which matches example shape
            if x.size(0) == 5:
                return x.sin()
            else:
                return x.cos()

        call_count = 0

        def _test_wrapper_codegen_statically_known_int_or_none_in_context():
            nonlocal call_count
            call_count += 1
            graph = V.graph
            input_layouts = [
                inp.layout
                for inp in graph.graph_inputs.values()
                if hasattr(inp, "layout")
            ]
            batch_dim = input_layouts[0].size[0]
            if call_count == 1:
                # testing fn_1
                assert (
                    WrapperCodeGen.statically_known_int_or_none(batch_dim) is None
                ), "Should not be statically known on first call"
            elif call_count == 2:
                # testing fn_2
                assert (
                    WrapperCodeGen.statically_known_int_or_none(batch_dim) == 5
                ), "Should be limited to exactly 5 on second call due to multiple constraints"
            elif call_count == 3:
                # testing fn_3
                assert (
                    WrapperCodeGen.statically_known_int_or_none(batch_dim) == 5
                ), "Should be exactly 5 on third call"

        class TestWrapperCodegen(WrapperCodeGen):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def generate(self, is_inference, *args, **kwargs):
                _test_wrapper_codegen_statically_known_int_or_none_in_context()
                return super().generate(is_inference, *args, **kwargs)

        if "cpu" not in device_codegens:
            register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
        orig_cpu_codegens = device_codegens["cpu"]
        try:
            register_backend_for_device(
                "cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen
            )
            # Compile each of the functions above, with an example input
            # that has 5 in the first dimension, but is marked as dynamic

            torch.compile(backend="inductor", dynamic=None)(fn_1)(_x)
            torch.compile(backend="inductor", dynamic=None)(fn_2)(_x)
            torch.compile(backend="inductor", dynamic=None)(fn_3)(_x)
        finally:
            register_backend_for_device(
                "cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen
            )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_unbacked_stride_nobreak(self, device):
        @torch.compile(fullgraph=True, dynamic=True)
        def f(x):
            a = x.item()
            torch._check_is_size(a)
            torch._check(a >= 1)
            torch._check(a <= 10)
            return torch.ones(a, a)

        f(torch.tensor([5], device=device))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symint_sum_list(self, device):
        @torch.compile()
        def f(xt):
            xs = xt.tolist()
            for x in xs:
                torch._check_is_size(x)
            y = sum(xs)
            return torch.zeros(y, device=device)

        f(torch.tensor([5] * 320))

    def test_sort_dynamic_shape_with_check(self, device):
        if TEST_WITH_ROCM or torch.device(device).type != GPU_TYPE:

            def check_count(n):
                self.assertEqual(metrics.generated_kernel_count, 0)

        else:

            def check_count(n):
                self.assertEqual(metrics.generated_kernel_count, n)

        # Test dynamic shapes with a statically known bound small enough to
        # generate a persistent sort kernel
        def fn(a, descending):
            torch._check(a.shape[-1] <= 256)
            return a.sort(dim=-1, stable=True, descending=descending)

        inp = torch.rand(10, 128, dtype=torch.float32, device=device)
        inp[:, 10:20] = 1.0
        inp[:, 30:40] = 1.0
        metrics.reset()

        opt_fn = torch.compile(fn, dynamic=True)
        expect = fn(inp, False)
        actual = opt_fn(inp, False)
        self.assertEqual(actual, expect)
        check_count(1)

        expect = fn(inp, True)
        actual = opt_fn(inp, True)
        self.assertEqual(actual, expect)
        check_count(2)

        # Non-power of two; materialize the slice contiguously so only the
        # (already dynamic) size changes
        inp = inp[:, :120].contiguous()

        expect = fn(inp, False)
        actual = opt_fn(inp, False)
        self.assertEqual(actual, expect)
        check_count(2)  # Reused existing kernel

        expect = fn(inp, True)
        actual = opt_fn(inp, True)
        self.assertEqual(actual, expect)
        check_count(2)  # Reused existing kernel


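# Generate per-device copies of TestInductorDynamic (e.g. TestInductorDynamicCPU,
# TestInductorDynamicCUDA); allow_xpu=True opts the XPU backend into the same set.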
instantiate_device_type_tests(TestInductorDynamic, globals(), allow_xpu=True)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
    if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
        run_tests(needs="filelock")