# Owner(s): ["module: inductor"]
import gc
import math
import sys
import unittest

import torch
import torch._dynamo.config as dynamo_config
import torch.backends.cuda
import torch.nn.functional as F
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.runtime.hints import DeviceProperties
from torch._inductor.utils import (
    run_and_get_code,
    run_and_get_graph_lowering,
    run_fw_bw_and_get_code,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    SM80OrLater,
)
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    freeze_rng_state,
    IS_FBCODE,
    skipIfRocm,
    TEST_WITH_ASAN,
)
from torch.testing._internal.inductor_utils import skipCUDAIf


try:
    try:
        import triton
        from triton import language as tl
    except ImportError:
        raise unittest.SkipTest("requires triton")  # noqa: B904

    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


TestCase = test_torchinductor.TestCase
ToTuple = test_torchinductor.ToTuple
check_model_cuda = test_torchinductor.check_model_cuda
aten = torch.ops.aten


class CudaReproTests(TestCase):
    device = "cuda"
    common = check_model_cuda

    def test_index_put_issue(self):
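        # Repro for aten.index_put_ with accumulate=True: the graph is traced
        # with make_fx, compiled with compile_fx_inner, then executed to make
        # sure the in-place scatter lowers and runs without error.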
        def forward(
            self,
            arg76_1,
            expand_default,
            full_like_default,
            _to_copy_default_67,
            zeros,
        ):
            sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True)
            view_default_57 = torch.ops.aten.view.default(sum_sym_int_19, [512, 768])
            where_self = torch.ops.aten.where.self(
                expand_default, view_default_57, full_like_default
            )
            clone_default_12 = torch.ops.aten.clone.default(zeros)
            index_put__default = torch.ops.aten.index_put_.default(
                clone_default_12, [arg76_1], where_self, True
            )
            return (index_put__default,)

        inps = [
            (torch.Size([512]), torch.int64),
            (torch.Size([512, 768]), torch.bool),
            (torch.Size([512, 768]), torch.float16),
            (torch.Size([4, 512, 768]), torch.float16),
            (torch.Size([512, 768]), torch.float16),
        ]
        inps = [torch.zeros(())] + [
            torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps
        ]
        mod = make_fx(forward)(*inps)
        compiled = compile_fx_inner(mod, inps)
        compiled(inps)

    @skipIfRocm
    def test_input_channels_last(self):
        m = torch.nn.Sequential(
            torch.nn.Conv2d(3, 3, 1, 1),
            ToTuple(),
        ).cuda()
        inp = torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda()

        self.common(
            m,
            (inp,),
            check_lowp=False,
        )

        @torch._dynamo.optimize()
        def foo(m, inp):
            return m(inp)

        self.assertTrue(foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last))

    # https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527
    def test_unspec_inputs_interop(self):
        class Repro(torch.nn.Module):
            def forward(self, x, y):
                unsqueeze = torch.ops.aten.unsqueeze.default(x, 4)
                permute = torch.ops.aten.permute.default(unsqueeze, [0, 1, 2, 4, 3])
                add = torch.ops.aten.add.Tensor(y, 1)
                return [permute, add]

        inps = [
            rand_strided((12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, "cuda"),
            rand_strided((), (), torch.int64, "cpu"),
        ]
        mod = make_fx(Repro().to(device="cuda"))(*inps)
        compiled = compile_fx_inner(mod, inps)
        compiled(inps)

    @unittest.skipIf(
        IS_FBCODE, "RuntimeError: Triton Error [CUDA]: invalid device context"
    )
    def test_backward_context(self):
        def fn(x):
            return x * 3

        x = torch.randn(4, device="cuda", requires_grad=True)
        gO = torch.rand_like(x)
        opt_fn = torch.compile(fn)
        out = opt_fn(x)
        out.backward(gO)

    @config.patch(fallback_random=True)
    def test_dtype_factory_issue(self):
        def forward():
            randn = torch.ops.aten.randn.default(
                [12, 64, 1, 64],
                dtype=torch.float32,
                device=torch.device(type="cuda", index=0),
                pin_memory=False,
            )
            unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1)
            return (unsqueeze_default_2,)

        mod = make_fx(forward)()
        compiled = compile_fx_inner(mod, ())
        assert compiled([])[0].device.type == "cuda"

    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_no_device_idx_repro_cudagraphs(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self):
                full = torch.ops.aten.full.default(
                    [8, 512],
                    1,
                    dtype=torch.float32,
                    layout=torch.strided,
                    device=torch.device(type="cuda", index=0),
                    pin_memory=False,
                )
                full_1 = torch.ops.aten.full.default(
                    [8, 512],
                    0,
                    dtype=torch.int64,
                    layout=torch.strided,
                    device=torch.device(type="cuda", index=0),
                    pin_memory=False,
                )
                return (full_1, full)

        self.common(Repro(), ())

    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_expanded_inputs_cudagraphs(self):
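        # The zero strides below come from rand_strided and mimic expanded
        # (broadcast) inputs; the compiled add should still match eager with
        # cudagraphs enabled.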
        @torch._dynamo.optimize("inductor")
        def fn(x, y):
            return x + y

        inputs = (
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
        )
        self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(
        automatic_dynamic_shapes=True,
        assume_static_by_default=False,
    )
    def test_dynamic_to_static_cudagraphs(self):
        for b in [False, True]:
            with config.patch({"triton.cudagraph_trees": b}):

                @torch._dynamo.optimize("inductor")
                def fn(x, y):
                    r = x + y
                    return r, r.size(0)

                inputs = (
                    torch.randn((5, 5), device="cuda"),
                    torch.randn((5, 5), device="cuda"),
                )
                self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5)))

                inputs = (
                    torch.randn((6, 6), device="cuda"),
                    torch.randn((6, 6), device="cuda"),
                )
                self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6)))

    @config.patch({"emulate_precision_casts": True})
    def test_emulate_low_precision(self):
        def foo(x):
            return torch.nn.functional.gelu(x) * 10.0

        inp = torch.rand([32], device="cuda", requires_grad=True, dtype=torch.bfloat16)
        out, codes = run_fw_bw_and_get_code(lambda: torch.compile(foo)(inp))

        # fwd, backward
        for code in codes:
            f = FileCheck()
            # in eager, there are two down casts
            for _ in range(2):
                f.check(".to(tl.bfloat16)").check_next(".to(tl.float32)")
            f.run(code)

        self.assertEqual(foo(inp), out)

    # TODO: Abstract this out, test more extensively
    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_dynamic_shapes(self):
        torch._dynamo.reset()  # Needed since everywhere else uses "inductor"

        def f(x):
            return x.cos().view(x.shape).sin()

        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        f2 = torch._dynamo.optimize(cnts)(f)

        f2(torch.randn(32))

        inp = torch.randn(16)
        real_out = f(inp)
        compiled_out = f2(inp)

        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(real_out, compiled_out)
        torch._dynamo.reset()

    @config.patch({"triton.cudagraphs": True, "size_asserts": False})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_expanded_inputs_cudagraphs_no_size_asserts(self):
        @torch._dynamo.optimize("inductor")
        def fn(x, y):
            return x + y

        inputs = (
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
            rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
        )
        self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

    @config.patch({"triton.cudagraph_trees": False})
    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_inplace_updates_cudagraphs(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = torch.nn.Parameter(
                    torch.randn(10, 20, requires_grad=True)
                )

            def forward(self, x):
                x = torch.matmul(x, self.weight1)
                return x

        from copy import deepcopy

        model = Repro().cuda()
        model_ref = deepcopy(model)
        model_opt = torch._dynamo.optimize("inductor")(model)

        input = torch.randn(10, 10, device="cuda", requires_grad=True)

        for i in range(2):
            output_ref = model_ref(input)
            output_res = model_opt(input)
            output_ref.sum().backward()
            output_res.sum().backward()
            for p_ref, p_res in zip(model_ref.parameters(), model_opt.parameters()):
                self.assertEqual(p_ref.grad, p_res.grad)
            with torch.no_grad():
                for param in model_ref.parameters():
                    param.add_(1.0)
                for param in model_opt.parameters():
                    param.add_(1.0)

    # https://github.com/pytorch/torchdynamo/issues/1850
    def test_inductor_output_aliases_intermediate(self):
        def foo(x):
            out = x + x
            return out.t()

        foo_opt = torch._dynamo.optimize("inductor")(foo)

        inpt = torch.randn(10, 10, device="cuda", requires_grad=True)
        # TODO: this is broken, fix later
        # out = foo_opt(inpt)
        # out.add_(2)

        out_ref = foo(inpt)
        out_ref.add_(2)
        # self.assertEqual(out_ref, out)

    def test_accuracy_issue1(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_features=768, out_features=2, bias=True
                )

            def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
                linear = self.linear(x)
                split = linear.split(1, dim=-1)
                getitem = split[0]
                squeeze = getitem.squeeze(-1)
                clamp = start_positions.clamp(0, 128)
                cross_entropy = torch.nn.functional.cross_entropy(
                    squeeze, clamp, None, None, 128, None, "mean", 0.0
                )
                return cross_entropy

        mod = Repro().cuda()
        opt_mod = torch._dynamo.optimize("inductor")(mod)
        mod.eval()
        opt_mod.eval()

        args = [
            ((1,), (1,), torch.int64, "cuda", False),
            ((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True),
        ]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]
        with torch.cuda.amp.autocast(enabled=False):
            assert same_two_models(mod, opt_mod, args), "Dynamo failed"

    @config.patch(allow_buffer_reuse=False)
    def test_issue103461(self):
        def forward(add_1):
            var_mean = torch.ops.aten.var_mean.correction(
                add_1, [2], correction=0, keepdim=True
            )
            getitem_1 = var_mean[1]
            return getitem_1

        x = torch.randn(1, 8, 768, device="cuda")
        correct = forward(x)
        actual = torch.compile(forward, fullgraph=True)(x)
        self.assertEqual(actual, correct)

    def test_full_copy(self):
        def forward(x):
            full_10 = torch.ops.aten.full.default(
                [204, 204, 28],
                0,
                dtype=torch.float64,
                layout=torch.strided,
                device="cuda",
                pin_memory=False,
            )
            return x + full_10.to("cpu")

        o = torch.randn([204, 204, 28], dtype=torch.float64)
        correct = forward(o)
        actual = torch.compile(forward, fullgraph=True)(o)
        self.assertEqual(actual, correct)

    def test_autotune_inplace_kernel(self):
406        """
407        This UT tests autotune on an inplace kernel. The autotune should not contaminate
408        the input buffers when tuning with multiple configs. For more details, refer to
409        https://github.com/openai/triton/issues/781
410        https://github.com/pytorch/torchdynamo/issues/1670
411        """
        from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
        from torch._inductor.runtime.hints import HeuristicType, instance_descriptor
        from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid

        def autotune(configs, meta):
            def decorator(fn):
                return CachingAutotuner(
                    # force autotune by setting save_cache_hook to False
                    fn,
                    triton_meta=meta,
                    configs=configs,
                    save_cache_hook=False,
                    mutated_arg_names=["in_out_ptr0"],
                    heuristic_type=HeuristicType.POINTWISE,
                )

            return decorator

        @autotune(
            configs=[
                triton.Config({"XBLOCK": 1}),
                triton.Config({"XBLOCK": 2}),
            ],
            meta={
                "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
                "device": DeviceProperties.create(torch.device("cuda")),
                "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
                "constants": {},
            },
        )
        @triton.jit
        def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):
            pid = tl.program_id(0)
            block_start = pid * XBLOCK
            offsets = block_start + tl.arange(0, XBLOCK)
            mask = offsets < xnumel
            x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0)
            y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0)
            output = x + y
            tl.store(in_out_ptr0 + offsets, output, mask=mask)

        xnumel = 384
        in0 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
        inout1 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
        inout2 = inout1.clone()

        stream0 = get_cuda_stream(0)
        kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)
        kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)

        assert same(
            inout1, inout2, tol=0.001, equal_nan=True
        ), "failed autotune with inplace kernel"

    def test_sort_stride_issue(self):
        # This minified testcase comes from detectron2_maskrcnn_r_50_fpn
        # There was a false error from our size_assert code
        @torch._dynamo.optimize(nopython=True)
        def forward(pred_objectness_logits_3_: torch.Tensor):
            sort_3 = pred_objectness_logits_3_.sort(descending=True, dim=1)
            getitem_12 = sort_3[0]
            return getitem_12

        args = [((1, 100), (0, 1), torch.float16, "cuda", False)]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]
        result = forward(*args)
        assert same(result, torch.sort(args[0], descending=True, dim=1)[0])

    def test_scalar_triton_index(self):
        # The indirect indexing via a scalar like below used to lead to
        # bad triton code that made triton segfault when compiling.
        # See https://github.com/pytorch/torchdynamo/issues/1515
        def fn(a):
            zero = torch.zeros((16,), device=a.device, dtype=torch.int64)
            return (a[zero],)

        a = torch.randn((8,), dtype=torch.float32, device="cuda")

        fn_optimized = torch._dynamo.optimize("inductor")(fn)
        assert same(fn(a), fn_optimized(a))

    def test_indirect_indexing_dense_mask(self):
        def fn(x, y):
            ne = torch.ops.aten.ne.Scalar(x, 1)
            sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1])
            sub = torch.ops.aten.sub.Tensor(sum_1, 1)
            unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1)
            gather = torch.ops.aten.gather.default(x, 1, unsqueeze)
            squeeze = torch.ops.aten.squeeze.default(gather)
            out = torch.ops.aten.multiply(y, squeeze)
            return (out,)

        a = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
        b = torch.zeros((1, 128), dtype=torch.int64, device="cuda")

        fn_optimized = torch._dynamo.optimize("inductor")(fn)
        assert same(fn(a, b), fn_optimized(a, b))

    def test_simplify_dims(self):
        def fn(a):
            return (a + 1,)

        self.common(fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],))

    @config.patch(permute_fusion=True)
    def test_permute_fusion(self):
        class Repro(torch.nn.Module):
            def forward(self, view, reshape_2):
                permute = view.permute(0, 2, 1)
                view = None
                reshape = torch.reshape(permute, (-1, 642))
                bmm = torch.bmm(permute, reshape_2)
                return (bmm,)

        args = [
            ((1024, 642, 160), (102720, 160, 1), torch.float32, "cuda", True),
            ((1024, 642, 20), (12840, 20, 1), torch.float32, "cuda", True),
        ]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]

        mod = Repro()
        opt_mod = torch._dynamo.optimize("inductor")(mod)

        ref = mod(*args)
        res = opt_mod(*args)
        self.assertTrue(same(ref, res))

    @config.patch({"triton.autotune_pointwise": True})
    def test_inplace_add_alpha_autotune(self):
        def fn(x, y):
            aten.add_.Tensor(x, y, alpha=0.55)
            return (x,)

        x1 = torch.zeros(2, 3, 4, 10, device="cuda")
        x2 = torch.zeros(2, 3, 4, 10, device="cuda")
        x3 = torch.zeros(2, 3, 4, 10, device="cuda")
        y = torch.randn(2, 3, 4, 10, device="cuda").to(
            memory_format=torch.channels_last
        )
        fn_fx = make_fx(fn)(x1, y)
        fn_compiled = compile_fx_inner(fn_fx, [x1, y])
        fn(x2, y)
        fn_compiled([x3, y])
        assert same(x2, x3)

    @config.patch({"triton.autotune_pointwise": True})
    def test_inplace_buffer_autotune(self):
        def foo(x, y, z):
            a = x @ y
            return a.unsqueeze(0).unsqueeze(0) + z

        x = torch.zeros(5, 5, device="cuda")
        y = torch.zeros(5, 5, device="cuda")
        z = torch.zeros(1, 1, 5, 5, device="cuda").to(memory_format=torch.channels_last)
        self.common(
            foo,
            (x, y, z),
            check_lowp=False,
        )

    def test_memory_history_inductor(self):
        def called_inside_compile(x, w, b):
            a = x @ w + b
            return torch.sigmoid(a)

        @torch.compile
        def fn(x, w, b):
            x = called_inside_compile(x, w, b)
            return called_inside_compile(x, w, b)

        w = torch.rand(3, 3, device="cuda")
        b = torch.rand(3, device="cuda")
        x = torch.rand(3, device="cuda")
        try:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history(True)
            r = fn(x, w, b)
        finally:
            torch.cuda.memory._record_memory_history(False)
        snapshot = str(torch.cuda.memory._snapshot())
        self.assertTrue("called_inside_compile" in snapshot)

    def test_negative_arange_dynamic_shapes(self):
        # Repro from alibi relative encodings
        def sign(x):
            return (x > 0) - (x < 0)

        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                nheads = 16
                start = math.log2(0.5)
                end = math.log2(1 / (2**8))

                self.scales = nn.Buffer(
                    2
                    ** torch.arange(
                        start,
                        end + 1e-6 * sign(end - start),
                        (end - start) / (nheads - 1),
                    ).view(1, nheads, 1, 1),
                )
                self.emb = nn.Embedding(1024, 256)
                self.dec_layer = nn.TransformerDecoderLayer(
                    256, 16, 512, batch_first=True, norm_first=True
                )
                self.head = nn.Linear(256, 1024)

            def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor):
                padmask = dec_in == 0
                dec_mask = padmask.unsqueeze(-1) == padmask.unsqueeze(-2)
                dec_mask = dec_mask.to(dtype=torch.float32)
                dec_mask = dec_mask.tril(diagonal=0).cuda()

                q_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
                k_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
                rel_pos = k_pos[None, :] - q_pos[:, None]
                values = rel_pos.abs().neg().unsqueeze(0).unsqueeze(0)
                dec_bias = values * self.scales
                dec_bias.tril_(diagonal=0)

                dec_mask = dec_mask + dec_bias[0]
                out = self.emb(dec_in)
                out = self.dec_layer(out, enc_out, tgt_mask=dec_mask)
                return self.head(out)

        mod = Repro().cuda()
        opt_mod = torch._dynamo.optimize("inductor", dynamic=True)(mod)
        mod.eval()
        opt_mod.eval()

        enc_out = torch.rand(1, 512, 256).cuda()
        dec_inputs = [
            torch.randint(0, 512, (1, i + 1), dtype=torch.long).cuda() for i in range(8)
        ]

        for dec_inp in dec_inputs:
            assert same_two_models(
                mod, opt_mod, [enc_out, dec_inp], only_fwd=True
            ), "Inductor with dynamic shapes failed"

    def test_issue97695_1input(self):
        def fn(arg3_1, relu, permute_1):
            addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
            cat_2 = torch.ops.aten.cat.default([addmm_1], 1)
            return (cat_2,)

        args = [
            ((96,), (1,), torch.float32, "cuda"),
            ((10, 256), (256, 1), torch.float32, "cuda"),
            ((256, 96), (1, 256), torch.float32, "cuda"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
        correct = fn(*args)

        mod = make_fx(fn, tracing_mode="real")(*args)
        compiled = compile_fx_inner(mod, args)
        ref = compiled(list(args))
        assert same(ref, correct)

        ref = torch.compile(fn, fullgraph=True)(*args)
        assert same(ref, correct)

    def test_issue_103924(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.temperature = 1
                self.layer = torch.nn.Softmax(dim=1)

            def forward(self, x):
                n_samples, _ = x.shape
                y = 1.0 * torch.ones(n_samples, dtype=x.dtype, device=x.device)
                inp = x / y[..., None]
                return self.layer(inp)

        x = torch.rand([4, 4], device="cuda")
        m = MyModule()
        opt_m = torch.compile(backend="inductor")(m)
        self.assertEqual(opt_m(x), m(x))

    def test_issue97695_2input(self):
        def fn(arg3_1, arg3_2, relu, permute_1):
            addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
            addmm_2 = torch.ops.aten.addmm.default(arg3_2, relu, permute_1)
            cat_2 = torch.ops.aten.cat.default([addmm_1, addmm_2], 1)
            return (cat_2,)

        args = [
            ((96,), (1,), torch.float32, "cuda"),
            ((96,), (1,), torch.float32, "cuda"),
            ((10, 256), (256, 1), torch.float32, "cuda"),
            ((256, 96), (1, 256), torch.float32, "cuda"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
        correct = fn(*args)

        ref = torch.compile(fn, fullgraph=True)(*args)
        assert same(ref, correct)

    def test_scatter_index_not_wrapped(self):
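        # The generated kernel is expected to bounds-check the scatter index
        # with tl.device_assert before the atomic_add, rather than wrapping
        # negative indices (verified by the FileCheck below).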
        src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], device=self.device)
        index = torch.tensor([0, 1, 0, 1, 2, 0], device=self.device)
        input = torch.tensor([1.0, 2.0, 3.0, 4.0], device=self.device)
        compiled_sr = torch.compile(torch.scatter_reduce)

        input_orig = input.clone()
        out, code = run_and_get_code(compiled_sr, input, 0, index, src, "sum")
        # tmp0 - not wrapping of negative numbers
        FileCheck().check("tl.device_assert(((0 <= tmp0) & (tmp0 < 4))").check_next(
            "atomic_add"
        ).run(code[0])
        self.assertEqual(
            out, torch.scatter_reduce(input_orig.clone(), 0, index, src, "sum")
        )

    def test_embedding_var_mean(self):
        def forward(arg0_1):
            full = torch.ops.aten.full.default(
                [1, 2048],
                1,
                dtype=torch.float32,
                layout=torch.strided,
                device=torch.device(type="cuda", index=0),
                pin_memory=False,
            )
            convert_element_type_1 = torch.ops.prims.convert_element_type.default(
                full, torch.int64
            )
            cumsum = torch.ops.aten.cumsum.default(convert_element_type_1, 1)
            mul = torch.ops.aten.mul.Tensor(cumsum, convert_element_type_1)
            sub_1 = torch.ops.aten.sub.Tensor(mul, 1)
            slice_5 = torch.ops.aten.slice.Tensor(sub_1, 0, 0, 9223372036854775807)
            slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807)
            add_2 = torch.ops.aten.add.Tensor(slice_6, 2)
            embedding_1 = torch.ops.aten.embedding.default(arg0_1, add_2)
            var_mean = torch.ops.aten.var_mean.correction(
                embedding_1, [2], correction=0, keepdim=True
            )
            return [var_mean[0], var_mean[1], add_2]

        emb = torch.randn([2050, 768], device="cuda")
        gm = make_fx(forward)(emb)
        opt = torch._inductor.compile_fx.compile_fx_inner(gm, [emb])
        opt([emb])
        torch.cuda.synchronize()

    def test_deterministic_algorithms(self):
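        # The indexed accumulation below (x[idx] += values with many duplicate
        # indices) is nondeterministic by default; under DeterministicGuard
        # repeated runs are expected to match bitwise (atol=0, rtol=0).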
        N = 10000

        @torch.compile
        def fn(idx, values):
            x = torch.zeros(1, device="cuda")
            x[idx] += values
            return x

        idx = torch.zeros(N, dtype=torch.int64, device="cuda")
        values = torch.randn(N, device="cuda")

        r0 = fn(idx, values)
        with DeterministicGuard(True):
            r1 = fn(idx, values)
            for _ in range(10):
                rn = fn(idx, values)
                self.assertEqual(r1, rn, atol=0, rtol=0)

    # https://github.com/pytorch/pytorch/issues/96406
    def test_linear_cpu_input(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 4)

            def forward(self, data):
                data = data.to("cuda")
                return self.linear(data)

        mod = Model().cuda().eval()
        with torch.no_grad():
            self.common(mod, (torch.randn(4, 4),))

    @config.patch({"fallback_random": True, "triton.cudagraphs": True})
    def test_xlnet_lm_stride_repro(self):
        class Repro(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout(p=0.1, inplace=False)

            def forward(self, x):
                y = torch._C._nn.gelu(x)
                return self.dropout(y)

        mod = Repro()
        x = torch.randn((512, 1, 4096), requires_grad=True, device="cuda")
        y = torch.compile(mod)(x)
        # Inductor claims the output layout of gelu's saved variable for
        # backwards will be (4096, 4096, 1) but in actuality it is (4096,
        # 2097152, 1).  Fortunately this doesn't actually matter in practice.
        y.sum().backward()

    def test_lookup_seed_backward(self):
        @torch.compile(fullgraph=True)
        def forward(inductor_seeds, mul_4, view_15):
            inductor_lookup_seed_2 = torch.ops.prims.inductor_lookup_seed.default(
                inductor_seeds, 2
            )
            inductor_random_2 = torch.ops.prims.inductor_random.default(
                [2, 512, 768], inductor_lookup_seed_2, "rand"
            )
            gt_2 = torch.ops.aten.gt.Scalar(inductor_random_2, 0.1)
            mul_7 = torch.ops.aten.mul.Tensor(gt_2, view_15)
            mul_8 = torch.ops.aten.mul.Tensor(mul_7, 1.1111111111111112)
            add_5 = torch.ops.aten.add.Tensor(mul_8, mul_4)
            var_mean_1 = torch.ops.aten.var_mean.correction(
                add_5, [2], correction=0, keepdim=True
            )
            getitem_3 = var_mean_1[1]
            sub_3 = torch.ops.aten.sub.Tensor(add_5, getitem_3)
            return (sub_3,)

        buf0 = torch.zeros((37,), dtype=torch.int64, device="cuda")
        buf1 = torch.zeros((2, 512, 768), device="cuda")
        buf2 = torch.zeros((2, 512, 768), device="cuda")
        forward(buf0, buf1, buf2)

    def test_issue100806(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 20)
                self.linear2 = torch.nn.Linear(20, 30)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = torch.cat((x, x), dim=1)
                x = x.view(-1, 2, 30)
                x = x[:, 1, :]
                x = self.relu(x)
                return x

        device = "cuda"
        batch_size = 2
        x = torch.randn(batch_size, 10).to(device)
        func = Model().to(device)

        with torch.no_grad():
            func.train(False)
            jit_func = torch.compile(func)

            res1 = func(x)
            res2 = jit_func(x)
            self.assertEqual(res1, res2)

    def test_issue103481(self):
        def fn(x, y):
            # NOTE: 6 dimensions is important! does not fail for 5 dimensions
            mean = torch.mean(x, [2, 3, 4, 5], keepdim=True)
            add = mean + y
            return add

        x = torch.rand(4, 4, 4, 4, 4, 4, device="cuda")
        y = torch.rand((), device="cuda")
        expect = fn(x, y)

        opt_fn = torch.compile(fn)
        actual = opt_fn(x, y)

        self.assertEqual(expect, actual)

    @config.patch({"triton.dense_indexing": True})
    @dynamo_config.patch(automatic_dynamic_shapes=True)
    def test_bucketize_dynamic_dense(self):
        """
        Make sure that ops.bucketize() can handle dense_indexing, which previously
        caused issues due to incorrect handling of the size of offsets.
        """

        def fn(values, offsets):
            return torch.bucketize(values, offsets)

        values = torch.rand((64, 64), device="cuda")
        offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device="cuda")

        expect = fn(values, offsets)

        opt_fn = torch.compile(fn, dynamic=True)
        actual = opt_fn(values, offsets)

        self.assertEqual(expect, actual)

    def test_float64_constants(self):
        def fn():
            # NOTE: tensors of all the same value are constant folded, so we
            # need a tensor with two distinct values
            a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device="cuda")
            return a * 2e50

        cfn = torch.compile(fn)
        expect = fn()
        actual = cfn()
        self.assertEqual(expect, actual, atol=0, rtol=0)

    def test_issue104759(self):
        def fn(arg7_1, add_1, permute_2, select_scatter, slice_8):
            slice_scatter_4 = torch.ops.aten.slice_scatter.default(
                permute_2, select_scatter, 0, 1, 9223372036854775807
            )
            permute_3 = torch.ops.aten.permute.default(slice_scatter_4, [1, 3, 0, 2, 4])
            view_6 = torch.ops.aten.view.default(permute_3, [1, 1000, 48])
            view_7 = torch.ops.aten.view.default(view_6, [1000, 48])
            view_8 = torch.ops.aten.view.default(view_7, [1, 1000, 48])
            view_9 = torch.ops.aten.view.default(view_8, [1, 1000, 3, 4, 4])
            permute_4 = torch.ops.aten.permute.default(view_9, [2, 0, 3, 1, 4])
            slice_7 = torch.ops.aten.slice.Tensor(permute_4, 0, 1, 9223372036854775807)
            slice_scatter_5 = torch.ops.aten.slice_scatter.default(
                slice_8, slice_7, 4, 0, 9223372036854775807
            )
            slice_scatter_6 = torch.ops.aten.slice_scatter.default(
                arg7_1, slice_scatter_5, 3, 0, 1000
            )
            mul_8 = torch.ops.aten.mul.Scalar(add_1, 0.7071067811865476)
            slice_9 = torch.ops.aten.slice.Tensor(slice_scatter_6, 3, 0, 1000)
            slice_10 = torch.ops.aten.slice.Tensor(slice_9, 4, 0, 9223372036854775807)
            select_2 = torch.ops.aten.select.int(slice_10, 0, 0)
            permute_5 = torch.ops.aten.permute.default(select_2, [0, 1, 3, 2])
            mul_9 = torch.ops.aten.mul.Scalar(permute_5, 0.7071067811865476)
            expand = torch.ops.aten.expand.default(mul_8, [1, 4, 1000, 4])
            view_10 = torch.ops.aten.view.default(expand, [4, 1000, 4])
            expand_1 = torch.ops.aten.expand.default(mul_9, [1, 4, 4, 1000])
            view_11 = torch.ops.aten.view.default(expand_1, [4, 4, 1000])
            bmm = torch.ops.aten.bmm.default(view_10, view_11)
            return (bmm,)

        args = []
        args.append(torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device="cuda"))
        args.append(
            rand_strided(
                (1, 4, 1000, 4), (16000, 4, 16, 1), dtype=torch.float16, device="cuda"
            )
        )
        args.append(
            rand_strided(
                (3, 1, 4, 1000, 4),
                (16, 48000, 4, 48, 1),
                dtype=torch.float16,
                device="cuda",
            )
        )
        args.append(
            rand_strided(
                (2, 1, 4, 1000, 4),
                (16, 48000, 4, 48, 1),
                dtype=torch.float16,
                device="cuda",
            )
        )
        args.append(
            rand_strided(
                (2, 1, 4, 1000, 4),
                (19200, 19200, 4800, 4, 1),
                dtype=torch.float16,
                device="cuda",
            )
        )

        correct = fn(*args)
        mod = make_fx(fn, tracing_mode="real")(*args)
        compiled = compile_fx_inner(mod, args)
        ref = compiled(list(args))
        assert same(ref, correct)

    @config.patch({"triton.cudagraphs": True})
    def test_index_put_inplace_cudagraph(self):
        def fn(x, y, z):
            x = torch.zeros_like(x)
            return x.index_put_([y], z, True)

        x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
        y = torch.zeros((512,), device="cuda", dtype=torch.int64)
        z = torch.ones((512, 512), device="cuda", dtype=torch.bool)

        opt_fn = torch._dynamo.optimize("inductor")(fn)

        ref = fn(x, y, z)

        # run it twice to test cuda graph issue
        res = opt_fn(x, y, z)
        res = opt_fn(x, y, z)

        self.assertEqual(ref, res)

    @config.patch({"triton.cudagraphs": True})
    @config.patch({"fx_graph_cache": True})
    def test_index_put_cudagraph(self):
        for _ in range(2):

            def fn(x, y, z):
                x = torch.zeros_like(x)
                return x.index_put([y], z, True)

            x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
            y = torch.zeros((512,), device="cuda", dtype=torch.int64)
            z = torch.ones((512, 512), device="cuda", dtype=torch.bool)

            opt_fn = torch._dynamo.optimize("inductor")(fn)

            ref = fn(x, y, z)

            # run it twice to test cuda graph issue
            res = opt_fn(x, y, z)
            res = opt_fn(x, y, z)

            self.assertEqual(ref, res)
            torch._dynamo.reset()
            gc.collect()

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
    )
    def test_flash_attention_dynamic(self):
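        # With dynamic=True, the varying sequence lengths (512, 513, 514)
        # should all be served by a single compiled graph, so frame_count
        # stays at 1.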
        class Model(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

                self.q = nn.Linear(1024, 1024)
                self.k = nn.Linear(1024, 1024)
                self.v = nn.Linear(1024, 1024)

            def forward(self, x):
                batch_size, seq_len, _ = x.size()

                queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
                keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
                values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)

                attn = F.scaled_dot_product_attention(
                    queries,
                    keys,
                    values,
                )

                return attn

        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        model = Model().cuda().half()
        model = torch.compile(model, backend=cnts, dynamic=True)

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=False,
            enable_mem_efficient=False,
            enable_cudnn=False,
        ):
            input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
            input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
            input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)

            out1 = model(input1)
            out2 = model(input2)
            out3 = model(input3)

        self.assertEqual(cnts.frame_count, 1)

    @config.patch({"triton.cudagraphs": True})
    def test_index_put_no_fallback_cudagraph(self):
        def fn(x, y, z):
            x = torch.zeros_like(x)
            return x.index_put([y], z, True)

        x = torch.zeros((512, 512), device="cuda", dtype=torch.int32)
        y = torch.zeros((512,), device="cuda", dtype=torch.int64)
        z = torch.ones((512, 512), device="cuda", dtype=torch.int32)

        opt_fn = torch._dynamo.optimize("inductor")(fn)

        ref = fn(x, y, z)

        # run it twice to test cuda graph issue
        res = opt_fn(x, y, z)
        res = opt_fn(x, y, z)

        self.assertEqual(ref, res)

    # https://github.com/pytorch/pytorch/issues/104937
    def test_linear_with_zero_infeature_size(self):
        m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")
        x = torch.rand(1, 1, 0, device="cuda")
        expect = m(x)
        opt_fn = torch.compile(m)
        actual = opt_fn(x)
        self.assertEqual(expect, actual)

    @config.patch(fallback_random=True)
    def test_multi_output_layout_fallback(self):
        mod = nn.RReLU(lower=3.2350976, upper=8.4220314, inplace=True)
        inp = torch.rand([4, 4]).cuda()
        m = torch.compile(mod)

        with freeze_rng_state():
            o1 = m(inp.clone())

        o2 = mod(inp.clone())

        self.assertEqual(o1, o2)

    def test_cat_int8_one_kernel(self):
        @torch.compile()
        def cat(inps):
            return torch.cat(inps) + 1

        for dtype in [torch.uint8, torch.int8]:
            inps = [
                torch.empty([256, 256], dtype=dtype, device="cuda") for _ in range(4)
            ]

            out, code = run_and_get_code(cat, inps)
            self.assertEqual(torch.cat(inps) + 1, out)
            FileCheck().check_not("aten.cat.default(").check_count(
                ".run(", 1, exactly=True
            ).run(code[0])

    @config.patch("triton.use_block_ptr", True)
    def test_selecsls42b_misaligned_address(self):
        # https://github.com/openai/triton/issues/2836

        @torch.compile(fullgraph=True)
        def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
            div = torch.ops.aten.div.Scalar(expand, 16)
            where = torch.ops.aten.where.self(arg207_1, full, div)
            convert_element_type_43 = torch.ops.prims.convert_element_type.default(
                where, torch.float32
            )
            sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
            sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
            mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
            sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
            mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
            unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
            unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
            unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
            mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
            mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
            unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
            unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
            unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
            mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
            sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
            sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
            return (sub_2,)

        args = [
            torch.randn((8, 1024, 4, 4), device="cuda") > 0,  # torch.bool tensor
            torch.randn((1, 1024, 1, 1), device="cuda"),
            torch.randn((8, 1024, 4, 4), device="cuda"),
            torch.randn((8, 1024, 1, 1), dtype=torch.float16, device="cuda").expand(
                (8, 1024, 4, 4)
            ),
            torch.randn((), device="cuda"),
            torch.randn((1024,), device="cuda"),
        ]
        fn(*args)
        torch.cuda.synchronize()  # shake out Triton Error [CUDA]: misaligned address

    @skipIfRocm
    def test_non_commutative_scan_op(self):
        from torch._higher_order_ops.associative_scan import associative_scan

        a = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")
        b = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")

        def baseline(v, u):
            A = []
            A.append(b[:, 0])
            for i in range(1, v.shape[1]):
                A.append(a[:, i] * A[i - 1] + b[:, i])
            return torch.stack(A, dim=1)

        def combine_fn(i, j):
            ia, ib = i
            ja, jb = j
            return ia * ja, ib * ja + jb

        @torch.compile
        def compiled_scan(a, b):
            return associative_scan(combine_fn, (a, b), dim=-1)[1]

        out1 = baseline(a, b)
        out2 = compiled_scan(a, b)
        self.assertEqual(out1, out2)

    def test_dynamic_persistent_reductions(self):
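        # The asserts below give the compiler an upper bound on the reduced
        # dimension, which should let Inductor emit persistent reductions:
        # the checks confirm no "for roffset" loop appears in the generated code.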
        @torch.compile(dynamic=True)
        def inner_reduce(x):
            assert x.shape[1] <= 1024
            return x.sum(1)

        a = torch.randn(50, 600, device="cuda")
        out, code = run_and_get_code(inner_reduce, a)
        self.assertEqual(inner_reduce(a), out)
        self.assertTrue("for roffset" not in code)

        @torch.compile(dynamic=True)
        def outer_reduce(x):
            assert x.shape[0] <= 64
            return x.sum(0)

        out, code = run_and_get_code(outer_reduce, a)
        self.assertEqual(outer_reduce(a), out)
        self.assertTrue("for roffset" not in code)

    def test_non_contiguous_unaligned_input_indices(self):
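        # The [1:] slice gives a tensor with a nonzero storage offset, so it is
        # not aligned to the boundary Inductor assumes; its index is expected
        # to be dropped from the aligned-input list.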
        from torch._inductor.compile_fx import remove_unaligned_input_idxs

        inputs = [torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")[1:]]
        idxs = remove_unaligned_input_idxs(inputs, [1])
        self.assertEqual(idxs, [])

        inputs = [
            torch.ones(2, 2, device="cuda"),
            torch.ones(2, 2, device="cuda"),
            torch.ones(2, 2, device="cuda")[1:],
        ]
        idxs = remove_unaligned_input_idxs(inputs, [0, 2])
        self.assertEqual(idxs, [0])

    @config.patch("triton.cudagraphs", True)
    def test_unused_cpu_input_cudagraphs(self):
        def fn(x, y):
            return x.sin().sin().sin().sin().cos() + 1

        fx_graph = torch.fx.symbolic_trace(fn)
        inp = [torch.randn(64, device="cuda"), torch.randn(64, device="cpu")]
        compiled_fn, (graph,) = run_and_get_graph_lowering(
            torch._inductor.compile, fx_graph, inp
        )
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})
        self.assertEqual(compiled_fn(*inp), fn(*inp))

    def test_epilogue_fusion_with_view(self):
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
                self.linear = torch.nn.Linear(262144, 100)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = x.view(x.size(0), -1)
                return self.relu(self.linear(x))

        m = ToyModel().to(device="cuda:0")
        input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
        from torch._inductor.utils import fresh_inductor_cache

        with fresh_inductor_cache():
            cm = torch.compile(m, mode="max-autotune")
            out = cm(input_tensor)
            out2 = m(input_tensor)
            self.assertEqual(out, out2, atol=1e-3, rtol=1e-3)

    @config.patch("triton.cudagraphs", True)
    def test_cpu_index(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            return x[torch.arange(32)]

        result, (graph,) = run_and_get_graph_lowering(
            fn, torch.randn(64, device="cuda")
        )
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})

        inp = torch.randn(64, device="cuda", requires_grad=True)
        result, (graph,) = run_and_get_graph_lowering(fn, inp)
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})

        result, (graph,) = run_and_get_graph_lowering(lambda: result.sum().backward())
        self.assertEqual(graph.disable_cudagraphs_reason, None)
        self.assertEqual(graph.device_types, {"cuda"})

    def test_reflection_pad_loop_order(self):
        def fn(x, y):
            a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect")
            b = torch.nn.functional.pad(y, (5, 5, 5, 5), mode="reflect")
            return a + b

        cfn = torch.compile(fn)
        a = torch.rand((10, 10, 10), device="cuda")
        b = torch.rand((10, 10, 10), device="cuda")
        expect = fn(a, b)
        actual, code = run_and_get_code(cfn, a, b)
        self.assertEqual(expect, actual)

        # Expect the code iterates in contiguous order, and is not tiled
        kernel_code = "\n".join(code[0].split("\n")[60:74])
        self.assertExpectedInline(
            kernel_code,
            """\
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 20
    x1 = (xindex // 20) % 20
    x2 = (xindex // 400)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x3), tmp2, xmask)""",  # noqa: B950
        )

    @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
    def test_int64_index_intermediate(self):
        def foo(inp):
            view_23 = torch.ops.aten.view.default(inp, [-1, 8192, 8192])
            split_1 = torch.ops.aten.split.Tensor(view_23, 1024, 1)
            view_23 = None
            getitem_17 = split_1[0]
            getitem_18 = split_1[1]
            getitem_19 = split_1[2]
            getitem_20 = split_1[3]
            getitem_21 = split_1[4]
            getitem_22 = split_1[5]
            getitem_23 = split_1[6]
            getitem_24 = split_1[7]
            split_1 = None
            cat_1 = torch.ops.aten.cat.default(
                [
                    getitem_17,
                    getitem_18,
                    getitem_19,
                    getitem_20,
                    getitem_21,
                    getitem_22,
                    getitem_23,
                    getitem_24,
                ]
            )
            getitem_17 = (
                getitem_18
            ) = (
                getitem_19
            ) = getitem_20 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None
            return cat_1

        for mark_dynamic in [False, True]:
            inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda")
            if mark_dynamic:
                torch._dynamo.mark_dynamic(inp, 0)
            foo_c = torch.compile(foo)
            torch.testing.assert_allclose(foo(inp), foo_c(inp))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CUDA

    if HAS_CUDA and not TEST_WITH_ASAN:
        run_tests(needs="filelock")