xref: /aosp_15_r20/external/pytorch/test/functorch/test_aotdispatch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: pt2"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import copy
10import itertools
11import unittest
12import warnings
13from contextlib import nullcontext
14from functools import partial, wraps
15from typing import Any, Callable, Dict, List, Optional, Union
16from unittest.mock import patch
17
18from common_utils import decorate, decorateForModules, skip, skipOps, xfail
19
20import torch
21import torch._dynamo as torchdynamo
22import torch.nn as nn
23import torch.utils._pytree as pytree
24from functorch import grad, jacrev, make_fx, vjp, vmap
25from functorch.compile import (
26    aot_function,
27    aot_module,
28    aot_module_simplified,
29    compiled_function,
30    compiled_module,
31    default_decompositions,
32    default_partition,
33    get_aot_compilation_context,
34    make_boxed_compiler,
35    make_boxed_func,
36    memory_efficient_fusion,
37    min_cut_rematerialization_partition,
38    nnc_jit,
39    nop,
40)
41from functorch.experimental import control_flow
42from torch._decomp import decomposition_table
43from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
44from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
45from torch._higher_order_ops.out_dtype import out_dtype
46from torch._inductor.codecache import compiled_fx_graph_hash
47from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
48from torch.fx.experimental.proxy_tensor import is_sym_node
49from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
50from torch.nn.utils.rnn import PackedSequence
51from torch.testing._internal.common_device_type import (
52    instantiate_device_type_tests,
53    ops,
54    tol,
55    toleranceOverride,
56)
57from torch.testing._internal.common_methods_invocations import op_db
58from torch.testing._internal.common_modules import module_db, modules
59from torch.testing._internal.common_utils import (
60    compare_equal_outs_and_grads,
61    instantiate_parametrized_tests,
62    IS_ARM64,
63    IS_MACOS,
64    IS_WINDOWS,
65    IS_X86,
66    outs_and_grads,
67    parametrize,
68    run_tests,
69    skipIfRocm,
70    skipIfTorchDynamo,
71    TestCase,
72    xfail_inherited_tests,
73    xfailIfTorchDynamo,
74)
75from torch.testing._internal.custom_tensor import ConstantExtraMetadataTensor
76from torch.testing._internal.hop_db import hop_db
77from torch.testing._internal.optests import (
78    _test_aot_autograd_forwards_backwards_helper,
79    aot_autograd_check,
80)
81from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode
82
83
84USE_TORCHVISION = False
85try:
86    import torchvision
87
88    USE_TORCHVISION = True
89except ImportError:
90    warnings.warn(
91        "Couldn't import torchvision. Some of our tests use it, try "
92        "to install it with commands from pytorch.org, post-fixed with "
93        "`--no-deps` to avoid overwriting the pytorch installation",
94        UserWarning,
95    )
96
97USE_NETWORKX = False
98try:
99    import networkx  # noqa: F401
100
101    USE_NETWORKX = True
102except ImportError:
103    warnings.warn("Some tests use networkx but it was not installed", UserWarning)
104
105# NB: numpy is a testing dependency!
106
107
108class AOTTestCase(TestCase):
109    pass
110
111
112class TestPythonKey(AOTTestCase):
113    def test_make_fx(self, device):
114        def f(x):
115            return torch.sin(x)
116
117        inp = torch.randn(3)
118        fx_f = make_fx(f)(inp)
119
120        new_inp = torch.randn(3)
121        self.assertEqual(fx_f(new_inp), f(new_inp))
122
123    def test_make_fx_grad(self, device):
124        def f(x):
125            return torch.sin(x).sum()
126
127        inp = torch.randn(3)
128        f = grad(f)
129        fx_f = make_fx(f)(inp)
130
131        new_inp = torch.randn(3)
132        self.assertEqual(fx_f(new_inp), f(new_inp))
133
134    def test_scalar_device(self, device):
135        def f(a, b):
136            return a + b
137
138        inps = [torch.randn(3, device=device), torch.tensor(5)]
139        fx_f = make_fx(f)(*inps)
140        self.assertEqual(fx_f(*inps), f(*inps))
141
142    def test_make_fx_vmap(self, device):
143        def f(x):
144            return torch.sin(x)
145
146        inp = torch.randn(5, 3)
147        f = vmap(f)
148        fx_f = make_fx(f)(inp)
149        new_inp = torch.randn(5, 3)
150        self.assertEqual(fx_f(new_inp), f(new_inp))
151
152    def test_make_fx_jacrev(self, device):
153        def f(x):
154            return x.sin().sum()
155
156        inp = torch.randn(3)
157        f = jacrev(jacrev(f))
158        fx_f = make_fx(f)(inp)
159        new_inp = torch.randn(3)
160        self.assertEqual(fx_f(new_inp), f(new_inp))
161
162    def test_make_fx_vjp(self, device):
163        def f(x):
164            return torch.sin(x).sum()
165
166        primals = torch.randn(3)
167        _, vjp_fn = vjp(f, primals)
168        cotangent = torch.randn(())
169        fx_f = make_fx(vjp_fn)(cotangent, True, True)
170        new_cotangent = torch.randn(())
171        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
172
173    def test_make_fx_functionalize(self, device):
174        from functorch.experimental import functionalize
175
176        def fn(a):
177            a = a * 2
178            a.relu_()
179            return a
180
181        a = torch.randn(3, device=device)
182        symbolic_gm = torch.fx.symbolic_trace(fn)
183        includes_method_relu_ = any(
184            str(n.target) == "relu_" for n in symbolic_gm.graph.nodes
185        )
186        self.assertTrue(includes_method_relu_)
187        # Also verifies fix for https://github.com/pytorch/pytorch/issues/84570
188        gm = make_fx(functionalize(symbolic_gm))(a)
189        includes_aten_relu = any(
190            n.target == torch.ops.aten.relu.default for n in gm.graph.nodes
191        )
192        self.assertTrue(includes_aten_relu)
193
194    def test_make_fx_no_decompose(self, device):
195        # FIXME
196        return self.skipTest("error: maximum recursion reached")
197
198        def f(x):
199            return torch.tanh(x).sum()
200
201        fx_f = make_fx(grad(f))(torch.randn(5))
202        ops = {i.target for i in fx_f.graph.nodes}
203
204        self.assertEqual(torch.ops.aten.tanh_backward in ops, True)
205
206        fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
207        ops = {i.target for i in fx_f.graph.nodes}
208        self.assertEqual(torch.ops.aten.tanh_backward in ops, False)
209
210    def test_nnc_jit(self, device):
211        def f(x):
212            return torch.sin(x)
213
214        jit_f = nnc_jit(f)
215
216        inp = torch.randn(3)
217        self.assertEqual(jit_f(inp), f(inp))
218
219    def test_nnc_scalar(self, device):
220        def f(x):
221            return torch.sin(x)
222
223        jit_f = nnc_jit(f)
224
225        inp = torch.randn(())
226        self.assertEqual(jit_f(inp), f(inp))
227
228    def test_nnc_pytrees(self, device):
229        def f(x):
230            return [torch.sin(x[0])]
231
232        jit_f = nnc_jit(f)
233
234        inp = [torch.randn(3)]
235        self.assertEqual(jit_f(inp), f(inp))
236
237    def test_external_calls(self, device):
238        def f(a, b):
239            return torch.mv(a, b)
240
241        jit_f = nnc_jit(f)
242        inp = [torch.randn(3, 3), torch.randn(3)]
243        self.assertEqual(jit_f(*inp), f(*inp))
244
245    def test_nnc_passthrough(self, device):
246        def f(x, y):
247            return x + y, y
248
249        inp = (torch.randn(3), torch.randn(3))
250        jit_f = nnc_jit(f)
251        self.assertEqual(jit_f(*inp), f(*inp))
252
253        def f(x):
254            x["a"] = x["a"] * 2
255            return x
256
257        inp = ({"a": torch.randn(3), "b": torch.randn(3)},)
258        jit_f = nnc_jit(f)
259        self.assertEqual(jit_f(*inp), f(*inp))
260
261    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
262    def test_resnet18_backward_trace(self, device):
263        mod = torchvision.models.resnet18()
264
265        def f(x):
266            out = mod(x)
267            out.sum().backward()
268            return [a.grad for a in mod.parameters()]
269
270        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
271        grads = f(inp)
272
273        mod.zero_grad()
274        mod(inp).sum().backward()
275        grads2 = [a.grad for a in mod.parameters()]
276        self.assertEqual(grads, grads2)
277
278
279def get_base(t):
280    return t._base if t._is_view() else t
281
282
283def is_in_base(t, maybe_tensors):
284    t_base = get_base(t)
285    for maybe_tensor in maybe_tensors:
286        if isinstance(maybe_tensor, torch.Tensor):
287            if t_base is get_base(maybe_tensor):
288                return True
289    return False
290
291
292def skipIfDynamoInput(reason):
293    """
294    Skip TestAOTAutograd if running with dynamo input
295    """
296
297    def decorator(func):
298        @wraps(func)
299        def wrapper(self, *args, **kwargs):
300            if isinstance(self, TestAOTAutogradWithDynamo):
301                self.skipTest(
302                    f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}"
303                )
304            else:
305                func(self, *args, **kwargs)
306
307        return wrapper
308
309    return decorator
310
311
312class TestAOTAutograd(AOTTestCase):
313    def run_autograd(
314        self,
315        f: Callable,
316        fw_graph_cell: List[Optional[Callable]],
317        decompositions: Optional[Dict],
318        keep_input_mutations: bool,
319        dynamic: bool,
320    ):
321        """
322        Runs aot_autograd with the specified settings on f.
323        """
324        if isinstance(f, nn.Module):
325            compiled_f = aot_module(
326                f,
327                fw_compiler=make_boxed_compiler(
328                    partial(extract_graph, graph_cell=fw_graph_cell)
329                ),
330                bw_compiler=nop,
331                decompositions=decompositions,
332                keep_inference_input_mutations=keep_input_mutations,
333                dynamic=dynamic,
334            )
335        else:
336            compiled_f = aot_function(
337                f,
338                fw_compiler=make_boxed_compiler(
339                    partial(extract_graph, graph_cell=fw_graph_cell)
340                ),
341                bw_compiler=nop,
342                decompositions=decompositions,
343                keep_inference_input_mutations=keep_input_mutations,
344                dynamic=dynamic,
345            )
346        return compiled_f
347
348    # test_mutation will:
349    # - Ensure that inputs are non-leaves, so our graphs can mutate them
350    # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs)
351    @patch("functorch.compile.config.debug_assert", True)
352    def verify_aot_autograd(
353        self,
354        f,
355        inp_: Union[Callable, List[Any]],
356        *,
357        test_mutation: bool = False,
358        keep_inp_mutations: bool = False,
359        decompositions: Optional[Dict] = None,
360        dynamic: bool = False,
361        # Only active when inp_ is Callable.
362        # TODO: probably consolidate all tests to make inp a Callable.
363        make_inputs_subclasses: bool = False,
364    ):
365        def make_inputs(inp_):
366            # Some tests pass in a callable for inp, to generate the inputs
367            # (useful if we want to generate complicated aliasing inputs)
368            if isinstance(inp_, Callable):
369                inp_callable = inp_
370                # The callable should return a tuple of f_inputs, f_graph_inputs
371                # (The idea is that we might want to compile a function with the graph inputs,
372                # but test autograd backprop all the way through the actual inputs)
373                with TwoTensorMode() if make_inputs_subclasses else nullcontext():
374                    inp, graph_inps = inp_callable()
375            else:
376                inp = []
377                # Our input clones need to mimic when inputs are duplicates of one another
378                dupes_map = {}
379                for i, x in enumerate(inp_):
380                    if x in dupes_map:
381                        x_dupe_idx = dupes_map[x]
382                        inp.append(inp[x_dupe_idx])
383                    else:
384                        dupes_map[x] = i
385                        if not isinstance(x, torch.Tensor):
386                            x_copy = x
387                        else:
388                            x_copy = x.clone().detach().requires_grad_(x.requires_grad)
389                            if x.requires_grad and not x.is_leaf:
390                                x_copy = x_copy.clone()
391
392                        inp.append(x_copy)
393
394                if test_mutation:
395                    # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves
396                    graph_inps = [x.add(1) for x in inp]
397                else:
398                    graph_inps = inp
399
400            return inp, graph_inps
401
402        def check_results(
403            ref_results,
404            test_results,
405            ref_graph_inps,
406            test_graph_inps,
407            ref_inp,
408            test_inp,
409        ):
410            ref_out, ref_grad = ref_results
411            test_out, test_grad = test_results
412            self.assertEqual(ref_grad, test_grad)
413            if isinstance(ref_out, torch.Tensor):
414                self.assertTrue(isinstance(test_out, torch.Tensor))
415                ref_out, test_out = [ref_out], [test_out]
416            for ref_o, test_o in zip(ref_out, test_out):
417                if isinstance(ref_o, torch.Tensor):
418                    self.assertEqual(ref_o.requires_grad, test_o.requires_grad)
419                    self.assertEqual(ref_o.is_leaf, test_o.is_leaf)
420                    ref_is_view_of_non_interm = is_in_base(
421                        ref_o, ref_graph_inps
422                    ) or is_in_base(ref_o, ref_out)
423                    test_is_view_of_non_interm = is_in_base(
424                        test_o, test_graph_inps
425                    ) or is_in_base(test_o, test_out)
426                    self.assertEqual(
427                        ref_is_view_of_non_interm, test_is_view_of_non_interm
428                    )
429                    self.assertEqual(ref_o, test_o)
430                    if test_mutation:
431                        # This tests that autograd meta is set properly on the output we can
432                        # mutate it.
433                        ref_o.add_(2)
434                        test_o.add_(2)
435                        self.assertEqual(ref_o, test_o)
436                        # Reverse the modification
437                        ref_o.sub_(2)
438                        test_o.sub_(2)
439                        self.assertEqual(ref_o, test_o)
440            for ref_i, test_i in zip(ref_inp, test_inp):
441                if isinstance(ref_i, torch.Tensor):
442                    self.assertEqual(ref_i.requires_grad, test_i.requires_grad)
443                self.assertEqual(ref_i, test_i)
444
445        for keep_input_mutations in [True] if keep_inp_mutations else [True, False]:
446            inp, graph_inps = make_inputs(inp_)
447            test_inp, test_graph_inps = make_inputs(inp_)
448            fw_graph_cell = [None]
449            compiled_f = self.run_autograd(
450                f, fw_graph_cell, decompositions, keep_input_mutations, dynamic
451            )
452            ref_results = outs_and_grads(f, graph_inps, inp)
453            test_results = outs_and_grads(compiled_f, test_graph_inps, test_inp)
454
455            check_results(
456                ref_results, test_results, graph_inps, test_graph_inps, inp, test_inp
457            )
458            if isinstance(self, TestAOTAutogradWithCache):
459                # When testing with cache, run compiled_f a second time
460                cached_inp, cached_graph_inps = make_inputs(inp_)
461                cached_results = outs_and_grads(
462                    compiled_f, cached_graph_inps, cached_inp
463                )
464                check_results(
465                    ref_results,
466                    cached_results,
467                    graph_inps,
468                    cached_graph_inps,
469                    inp,
470                    cached_inp,
471                )
472
473        return fw_graph_cell[0]
474
475    def test_non_tensor_and_none_inputs(self):
476        # int, None, Tensor
477        def f(a, b, c):
478            return a * c
479
480        inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=True)]
481        self.verify_aot_autograd(f, inp)
482        inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=False)]
483        self.verify_aot_autograd(f, inp)
484
485    def test_single_output(self):
486        def f(a, b):
487            return a + b
488
489        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
490        self.verify_aot_autograd(f, inp)
491        inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)]
492        self.verify_aot_autograd(f, inp)
493
494    def test_multi_output(self):
495        def f(a, b):
496            return a + b, a - b
497
498        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
499        self.verify_aot_autograd(f, inp)
500        inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)]
501        self.verify_aot_autograd(f, inp)
502
503    def test_multi_output_list(self):
504        def f(a, b):
505            return [a + b, a - b]
506
507        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
508        self.verify_aot_autograd(f, inp)
509        inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)]
510        self.verify_aot_autograd(f, inp)
511
512    # Test for bug occurring at the intersection of fake tensors & functionalization.
513    def test_squeeze_mutation(self):
514        def f(a):
515            b = a.clone().squeeze(-1)
516            b.add_(1.0)
517            return a + b
518
519        inp = [torch.randn(3, 1, requires_grad=True)]
520        self.verify_aot_autograd(f, inp, dynamic=True)
521        inp = [torch.randn(3, 1, requires_grad=False)]
522        self.verify_aot_autograd(f, inp, dynamic=True)
523
524    def test_complex_linear(self):
525        # https://github.com/pytorch/pytorch/issues/93424
526        inp = [torch.randn(1, 10, 10, dtype=torch.complex64)]
527
528        class F(torch.nn.Module):
529            def __init__(self) -> None:
530                super().__init__()
531                self.linear = nn.Linear(10, 10, dtype=torch.complex64)
532
533            def forward(self, x):
534                return self.linear(x).sum().abs()
535
536        self.verify_aot_autograd(F(), inp)
537
538    def test_embedding_bag_view_dynamic(self):
539        # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper;
540        # test that this works even though the sparse tensor has no storage.
541
542        class F(torch.nn.Module):
543            def __init__(self) -> None:
544                super().__init__()
545                self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True)
546
547            def forward(self, x, y):
548                return self.emb(x, y).view(-1)
549
550        x = torch.arange(3)
551        y = torch.arange(3)
552        self.verify_aot_autograd(F(), [x, y], dynamic=False)
553        self.verify_aot_autograd(F(), [x, y], dynamic=True)
554
555    def test_input_mutation_simple(self):
556        def f(a):
557            a.mul_(2)
558            return a * 3
559
560        inp = [torch.ones(3, 3, requires_grad=True)]
561        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
562        inp = [torch.ones(3, 3, requires_grad=False)]
563        self.verify_aot_autograd(f, inp, test_mutation=True)
564        # Things to note:
565        # - the extra clone is because we need to pass the pre-mutated input to grad(),
566        #   but autograd operates above functionalization so we need to manually clone.
567        #   Hopefully backends can optimize this easily.
568        # - The extra return arg is because the compiled forward returns (mutated inputs + outputs)
569        self.assertExpectedInline(
570            fw_graph.code.strip(),
571            """\
572def forward(self, primals_1):
573    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
574    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
575    mul_1 = torch.ops.aten.mul.Tensor(mul, 3)
576    return (mul, mul_1)""",
577        )
578
579    def test_input_mutation_set__input_mutation(self):
580        def f(a):
581            b = torch.arange(9, dtype=a.dtype).reshape(3, 3)
582            with torch.no_grad():
583                a.set_(b)
584            return a * b
585
586        inp = [torch.ones(3, 3, requires_grad=True)]
587        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
588        inp = [torch.ones(3, 3, requires_grad=False)]
589        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
590
591    def test_set__steals_view_chain(self):
592        def f(a, b):
593            a_ = a.mul(2)
594            b_ = b.mul(2)
595            b_slice = b_[1].view(3, 3)
596            # a_clone should inherit the view chain from b_slice
597            a_.set_(b_slice)
598            # Also mutates b_,
599            a_.view(-1).mul_(2)
600            return a_ * b_slice
601
602        inp = [
603            torch.ones(3, 3, requires_grad=False),
604            torch.zeros(3, 9, requires_grad=False),
605        ]
606        self.verify_aot_autograd(f, inp, keep_inp_mutations=True)
607
608    @skipIfDynamoInput(
609        "Test doesn't make sense with dynamo, which changes order of mutations"
610    )
611    def test_set__and_data_mutation_good(self):
612        def f(a, b):
613            # The data mutation happens *after* the set_(). This is ok (see the graph below)
614            with torch.no_grad():
615                a.set_(b)
616                b.mul_(2)
617            return a + b
618
619        inp = [
620            torch.ones(3, 3, requires_grad=True),
621            torch.ones(3, 3, requires_grad=True),
622        ]
623        fw_graph = self.verify_aot_autograd(
624            f, inp, test_mutation=True, keep_inp_mutations=True
625        )
626        inp = [
627            torch.ones(3, 3, requires_grad=False),
628            torch.zeros(3, 3, requires_grad=False),
629        ]
630        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
631        # Important things to note:
632        # - "return a.set_(b)" desugars into "return b"
633        # - Both a and b are recorded as experiencing mutations,
634        #   which is why we see "b_updated" (output of the mul) twice in the graph outputs.
635        #   a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage).
636        # - the runtime epilogue for a is "a.set_(mul)"
637        # - the runtime epilogue for b is "b.copy_(mul)"
638        self.assertExpectedInline(
639            fw_graph.code.strip(),
640            """\
641def forward(self, primals_1, primals_2):
642    mul = torch.ops.aten.mul.Tensor(primals_2, 2)
643    add = torch.ops.aten.add.Tensor(mul, mul)
644    set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul);  primals_1 = set_ = None
645    copy_ = torch.ops.aten.copy_.default(primals_2, mul);  primals_2 = mul = copy_ = None
646    return (add,)""",
647        )
648
649    # This is a (hopefully) extremely rare case that is difficult to handle,
650    # so we ban it.
651    # https://github.com/pytorch/pytorch/issues/126236
652    # https://github.com/pytorch/pytorch/pull/126113
653    @xfailIfTorchDynamo
654    def test_set__and_data_mutation_bad(self):
655        def f(a):
656            a_view = a.view(-1)
657            tmp = torch.ones(3, 3, requires_grad=True)
658            # Now, any mutations on either tmp
659            # will be tracked as graph input mutations.
660            with torch.no_grad():
661                a.set_(tmp)
662                # BAD: a_view is now detached from every graph input,
663                # so we won't recognize that this caused an input mutation!
664                a_view.mul_(2)
665            return a + tmp
666
667        inp = [torch.ones(3, 3, requires_grad=True)]
668        with self.assertRaisesRegex(
669            RuntimeError, "cannot mutate tensors with frozen storage"
670        ):
671            self.verify_aot_autograd(
672                f, inp, test_mutation=True, keep_inp_mutations=True
673            )
674
675    @skipIfDynamoInput(
676        "Test doesn't make sense with dynamo, which changes order of mutations"
677    )
678    def test_set__not_allowed(self):
679        def f(a, b):
680            with torch.no_grad():
681                a.set_(b)
682            # Mutating a will change a's grad_fn, which requires us to replay the mutation outside of the graph.
683            # We currently ban this today, when the input also received a set_() input mutation.
684            a.mul_(2)
685            return a + b
686
687        inp = [
688            torch.ones(3, 3, requires_grad=True),
689            torch.ones(3, 3, requires_grad=True),
690        ]
691        with self.assertRaisesRegex(
692            AssertionError, "but the input has other mutations that we cannot"
693        ):
694            fw_graph = self.verify_aot_autograd(
695                f, inp, test_mutation=True, keep_inp_mutations=True
696            )
697
698    def test_input_mutation_set__nop(self):
699        def f(a):
700            b = torch.arange(9, dtype=a.dtype)
701            a_old = torch.ops.aten.alias.default(a)
702            with torch.no_grad():
703                a.set_(b)
704                a.set_(a_old)
705            return a + b.reshape(3, 3)
706
707        inp = [torch.ones(3, 3, requires_grad=True)]
708        fw_graph = self.verify_aot_autograd(
709            f, inp, test_mutation=True, keep_inp_mutations=True
710        )
711        inp = [torch.ones(3, 3, requires_grad=False)]
712        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
713        # Things to note:
714        # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b")
715        # - There is only **1** graph output. We properly realized that the two set_() calls
716        #   undo each other, and so effectively no inputs are mutated.
717        self.assertExpectedInline(
718            fw_graph.code.strip(),
719            """\
720def forward(self, primals_1):
721    arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
722    alias = torch.ops.aten.alias.default(primals_1);  primals_1 = None
723    view = torch.ops.aten.view.default(arange, [3, 3]);  arange = None
724    add = torch.ops.aten.add.Tensor(alias, view);  alias = view = None
725    return (add,)""",
726        )
727
728    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case")
729    @unittest.skipIf(IS_MACOS, "TODO: need to fix the test case")
730    def test_input_mutation_fsdp_set__into_same_input(self):
731        import torch.distributed._composable.fsdp._fsdp_param
732
733        def f(a):
734            b = torch.arange(9, dtype=a.dtype).view(3, 3)
735            c = torch.arange(9, dtype=a.dtype).view(3, 3)
736            d = torch.arange(9, dtype=a.dtype).view(3, 3)
737            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
738                torch.ops.fsdp.set_.default(a, b)
739            x = a * a
740            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
741                torch.ops.fsdp.set_.default(a, c)
742            y = a * a
743            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a):
744                torch.ops.fsdp.set_.default(a, c)
745            z = a * a
746            return x + y + z
747
748        inp = [torch.ones(3, 3, requires_grad=True)]
749        fw_graph = self.verify_aot_autograd(
750            f, inp, test_mutation=True, keep_inp_mutations=True
751        )
752        inp = [torch.ones(3, 3, requires_grad=False)]
753        self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
754        """
755        Expected behavior:
756        (1) When there are multiple set_() calls on the same graph input primal_X,
757        we want those set_() calls to all show up with primal_X as the first arg in the graph.
758        (2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892),
759        but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior.
760        """
761        self.assertExpectedInline(
762            fw_graph.code.strip(),
763            """\
764def forward(self, primals_1):
765    arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
766    view = torch.ops.aten.view.default(arange, [3, 3]);  arange = None
767    arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
768    view_1 = torch.ops.aten.view.default(arange_1, [3, 3]);  arange_1 = None
769    set_ = torch.ops.fsdp.set_.default(primals_1, view);  view = set_ = None
770    mul = torch.ops.aten.mul.Tensor(primals_1, primals_1)
771    set__1 = torch.ops.fsdp.set_.default(primals_1, view_1);  set__1 = None
772    mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
773    set__2 = torch.ops.fsdp.set_.default(primals_1, view_1);  view_1 = set__2 = None
774    mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1)
775    add = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
776    add_1 = torch.ops.aten.add.Tensor(add, mul_2);  add = mul_2 = None
777    return (add_1, primals_1)""",
778        )
779        self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp))
780
781    def test_input_mutation_simple_with_none_and_nontensor(self):
782        # Tensor, None, int
783        def f(a, b, c):
784            return a * c
785
786        f_compiled = aot_function(f, nop)
787        for req_grad in [True, False]:
788            inp = [torch.ones(3, 3, requires_grad=req_grad), None, 3]
789            out_ref = f(*inp)
790            out_test = f_compiled(*inp)
791            self.assertEqual(out_ref, out_test)
792
793    # https://github.com/pytorch/pytorch/issues/93363
794    def test_mutates_input_noncontiguous(self):
795        def f(a):
796            a.add_(1)
797            return ()
798
799        f_compiled = aot_function(f, nop)
800        ref = torch.ones(4, requires_grad=True) + 0
801        ref_view = ref[0::2]
802
803        test = torch.ones(4, requires_grad=True) + 0
804        test_view = test[0::2]
805
806        out_ref = f(ref_view)
807        out_test = f_compiled(test_view)
808        self.assertEqual(ref, test)
809
810    def test_input_mutation_modifies_autograd_meta_of_aliases(self):
811        def f(a):
812            a.mul_(2)
813            out = a + 1
814            return out.detach()
815
816        x_ref = torch.ones(3, 3, requires_grad=True).clone()
817        x_ref_view = x_ref.view(3, 3)
818
819        x_test = torch.ones(3, 3, requires_grad=True).clone()
820        x_test_view = x_test.view(3, 3)
821
822        f_compiled = aot_function(f, nop, keep_inference_input_mutations=True)
823        f(x_ref)
824        f_compiled(x_test)
825        # f will mutate aliases of the input, including its autograd metadata!
826        # y.grad_fn is AsStridedBackward
827        self.assertEqual(x_ref_view, x_test_view)
828        self.assertEqual(x_ref_view._version, x_test_view._version)
829        self.assertEqual(x_ref_view.grad_fn.__class__, x_test_view.grad_fn.__class__)
830        # Test the actual gradients are correct
831        (x_ref * x_ref_view).sum().backward()
832        (x_test * x_test_view).sum().backward()
833        self.assertEqual(x_ref.grad, x_test.grad)
834        self.assertEqual(x_ref_view.grad, x_test_view.grad)
835
836    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
837    def test_nested_subclasses(self):
838        @torch.compile(backend="aot_eager")
839        def f(x):
840            return x.sin().cos()
841
842        a = torch.ones(4, requires_grad=True)
843        a2 = a.clone().detach().requires_grad_()
844        a3 = a.clone().detach().requires_grad_()
845        a4 = a.clone().detach().requires_grad_()
846        aa = TwoTensor(a, a2)
847        aa2 = TwoTensor(a3, a4)
848        aaaa = TwoTensor(aa, aa2)
849        out = f(aaaa)
850        self.assertTrue(isinstance(out, TwoTensor))
851        self.assertTrue(isinstance(out.a, TwoTensor))
852        self.assertTrue(isinstance(out.b, TwoTensor))
853        self.assertTrue(isinstance(out.a.a, torch.Tensor))
854        self.assertTrue(isinstance(out.a.b, torch.Tensor))
855        self.assertTrue(isinstance(out.b.a, torch.Tensor))
856        self.assertTrue(isinstance(out.b.b, torch.Tensor))
857
858        out.sum().backward()
859        self.assertTrue(isinstance(aaaa.grad, TwoTensor))
860        self.assertTrue(isinstance(aaaa.grad.a, TwoTensor))
861        self.assertTrue(isinstance(aaaa.grad.b, TwoTensor))
862
863    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
864    def test_nested_subclasses_non_nested_grad(self):
865        @torch.compile(backend="aot_eager")
866        def f(x):
867            return x.sin().cos()
868
869        a = torch.ones(4, requires_grad=True)
870        a2 = a.clone().detach().requires_grad_()
871        a3 = a.clone().detach().requires_grad_()
872        a4 = a.clone().detach().requires_grad_()
873        new_aa = TwoTensor(a3, a4)
874        aa = TwoTensor(a, a2)
875
876        aa2 = aa.clone().detach().requires_grad_()
877        aaaa = TwoTensor(aa, aa2)
878        out = f(new_aa)
879        new_out = out + aaaa
880        with self.assertRaisesRegex(
881            RuntimeError,
882            "The grad inputs should be same tensor subclass type as forward output",
883        ):
884            new_out.sum().backward()
885
886    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
887    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
888    def test_custom_tensor_metadata(self):
889        def f(x):
890            x_elem = x.elem
891            x_elem_elem = x_elem.elem
892            x_elem_metadata = x_elem.constant_attribute
893            return x * x_elem * x_elem_elem * x_elem_metadata
894
895        a = torch.ones(4, requires_grad=True)
896        custom_a = ConstantExtraMetadataTensor(a)
897        custom_a.constant_attribute = 6
898        custom_aa = ConstantExtraMetadataTensor(custom_a)
899        custom_aa.constant_attribute = 4
900
901        custom_aa_compile = custom_aa.clone().detach().requires_grad_()
902        custom_aa_compile.elem.constant_attribute = 6
903        out_eager = f(custom_aa)
904
905        compiled_f = torch.compile(f, backend="aot_eager")
906        out = compiled_f(custom_aa_compile)
907
908        self.assertTrue(torch.allclose(out_eager, out))
909
910        out.sum().backward()
911
912        self.assertTrue(isinstance(custom_aa_compile.grad, ConstantExtraMetadataTensor))
913        self.assertTrue(
914            isinstance(custom_aa_compile.grad.elem, ConstantExtraMetadataTensor)
915        )
916
917    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
918    def test_nested_subclasses_complicated_inps(self):
919        def f(x, y, z):
920            temp = x + y
921            temp_plain = x.a + y.b
922            res = temp.sum() + temp_plain.sum()
923            return x.sin().cos() + res
924
925        x = torch.ones(4, requires_grad=True)
926        x2 = x.clone().detach().requires_grad_()
927        xx = TwoTensor(x, x2)
928        xx2 = xx.clone().detach().requires_grad_()
929
930        x_nested = TwoTensor(xx, xx2)
931        x_nested_compile = x_nested.clone().detach().requires_grad_()
932
933        y_nested = x_nested.clone().detach().requires_grad_()
934        y_nested_compile = y_nested.clone().detach().requires_grad_()
935
936        z = x.clone().detach().requires_grad_()
937        z_compile = z.clone().detach().requires_grad_()
938
939        out_eager = f(x_nested, y_nested, z)
940        compiled_f = torch.compile(f, backend="aot_eager")
941        out = compiled_f(x_nested_compile, y_nested_compile, z_compile)
942        self.assertTrue(torch.allclose(out_eager, out))
943
944        self.assertTrue(isinstance(out, TwoTensor))
945        self.assertTrue(isinstance(out.a, TwoTensor))
946        self.assertTrue(isinstance(out.b, TwoTensor))
947        self.assertTrue(isinstance(out.a.a, torch.Tensor))
948        self.assertTrue(isinstance(out.a.b, torch.Tensor))
949        self.assertTrue(isinstance(out.b.a, torch.Tensor))
950        self.assertTrue(isinstance(out.b.b, torch.Tensor))
951
952        out.sum().backward()
953        out_eager.sum().backward()
954
955        self.assertTrue(isinstance(x_nested_compile.grad, TwoTensor))
956        self.assertTrue(isinstance(x_nested_compile.grad.a, TwoTensor))
957        self.assertTrue(isinstance(x_nested_compile.grad.b, TwoTensor))
958
959        self.assertTrue(isinstance(y_nested_compile.grad, TwoTensor))
960        self.assertTrue(isinstance(y_nested_compile.grad.a, TwoTensor))
961        self.assertTrue(isinstance(y_nested_compile.grad.b, TwoTensor))
962
963        self.assertTrue(torch.allclose(x_nested_compile.grad.a.a, x_nested.grad.a.a))
964        self.assertTrue(torch.allclose(x_nested_compile.grad.a.b, x_nested.grad.a.b))
965        self.assertTrue(torch.allclose(y_nested_compile.grad.a.a, y_nested.grad.a.a))
966        self.assertTrue(torch.allclose(y_nested_compile.grad.a.b, y_nested.grad.a.b))
967
968    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
969    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470")
970    def test_nested_subclasses_complicated_inps_mixed(self):
971        def f(x, y):
972            y_elem = y.elem
973            y_elem_elem = y_elem.elem
974            y_elem_metadata = y_elem.constant_attribute
975            return y * y_elem * y_elem_elem * y_elem_metadata + x
976
977        x = torch.ones(4, requires_grad=True)
978        x2 = x.clone().detach().requires_grad_()
979        xx = TwoTensor(x, x2)
980        xx2 = xx.clone().detach().requires_grad_()
981
982        x_nested = TwoTensor(xx, xx2)
983        x_nested_compile = x_nested.clone().detach().requires_grad_()
984
985        a = torch.ones(4, requires_grad=True)
986        custom_a = ConstantExtraMetadataTensor(a)
987        custom_a.constant_attribute = 6
988        custom_aa = ConstantExtraMetadataTensor(custom_a)
989        custom_aa.constant_attribute = 4
990
991        custom_aa_compile = custom_aa.clone().detach().requires_grad_()
992        custom_aa_compile.constant_attribute = 4
993        custom_aa_compile.elem.constant_attribute = 6
994
995        compiled_f = torch.compile(f, backend="aot_eager")
996        out_eager = f(x_nested, custom_aa)
997        out = compiled_f(x_nested_compile, custom_aa_compile)
998        self.assertTrue(torch.allclose(out_eager, out))
999
1000        out.sum().backward()
1001        out_eager.sum().backward()
1002
1003        self.assertTrue(torch.allclose(x_nested_compile.grad, x_nested.grad))
1004        self.assertTrue(torch.allclose(custom_aa_compile.grad, custom_aa.grad))
1005
1006    @skipIfTorchDynamo("This test suite already uses dynamo")
1007    def test_composite_impl_compile(self):
1008        class Foo(torch.nn.Module):
1009            def __init__(self) -> None:
1010                super().__init__()
1011                self.linear = torch.nn.Linear(3, 3)
1012
1013            def forward(self, a):
1014                return self.linear(a)
1015
1016        inp = [torch.ones(3, 3, requires_grad=True)]
1017        fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True)
1018        inp = [torch.ones(3, 3, requires_grad=False)]
1019        self.assertExpectedInline(
1020            fw_graph.code.strip(),
1021            """\
1022def forward(self, primals_1, primals_2, primals_3):
1023    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
1024    addmm = torch.ops.aten.addmm.default(primals_2, primals_3, t);  primals_2 = None
1025    return (addmm, primals_3, t)""",
1026        )
1027
1028        with torch.inference_mode():
1029            fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True)
1030            inp = [torch.ones(3, 3, requires_grad=False)]
1031            self.assertExpectedInline(
1032                fw_graph.code.strip(),
1033                """\
1034def forward(self, arg0_1, arg1_1, arg2_1):
1035    t = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
1036    addmm = torch.ops.aten.addmm.default(arg1_1, arg2_1, t);  arg1_1 = arg2_1 = t = None
1037    return (addmm,)""",
1038            )
1039
1040    def test_outputs_are_aliased(self):
1041        # Tensor, None, int
1042        def f(a):
1043            b = a.mul(2)
1044            c = b.view(-1)
1045            return b, c
1046
1047        f_compiled = aot_function(f, nop)
1048        for req_grad in [True, False]:
1049            inp = torch.ones(3, requires_grad=req_grad)
1050            out_ref = f(inp)
1051            out_test = f_compiled(inp)
1052            self.assertEqual(out_ref[0], out_test[0])
1053            self.assertEqual(out_ref[1], out_test[1])
1054            # Try mutating one of the outputs, which is aliased.
1055            out_ref[0].mul_(3)
1056            out_test[0].mul_(3)
1057            # Assert that the aliasing relationship was preserved
1058            self.assertEqual(out_ref[0], out_test[0])
1059            self.assertEqual(out_ref[1], out_test[1])
1060
1061    def test_input_mutation_is_output(self):
1062        def f(a):
1063            a.mul_(2)
1064            return a
1065
1066        inp = [torch.ones(3, 3, requires_grad=True)]
1067        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1068        inp = [torch.ones(3, 3, requires_grad=False)]
1069        self.verify_aot_autograd(f, inp, test_mutation=True)
1070        self.assertExpectedInline(
1071            fw_graph.code.strip(),
1072            """\
1073def forward(self, primals_1):
1074    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
1075    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
1076    return (mul, mul)""",
1077        )
1078
1079    def test_input_mutation_multiple(self):
1080        def f(a, b, c):
1081            a.mul_(2)
1082            c.mul_(2)
1083            return a + b + c
1084
1085        def create_inp(req_grad):
1086            return [
1087                torch.ones(3, 3, requires_grad=req_grad),
1088                torch.ones(3, 3, requires_grad=req_grad),
1089                torch.ones(3, 3, requires_grad=req_grad),
1090            ]
1091
1092        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
1093
1094        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
1095        self.assertExpectedInline(
1096            fw_graph.code.strip(),
1097            """\
1098def forward(self, primals_1, primals_2, primals_3):
1099    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
1100    clone_1 = torch.ops.aten.clone.default(primals_3);  primals_3 = None
1101    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
1102    mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2);  clone_1 = None
1103    add = torch.ops.aten.add.Tensor(mul, primals_2);  primals_2 = None
1104    add_1 = torch.ops.aten.add.Tensor(add, mul_1);  add = None
1105    return (mul, mul_1, add_1)""",
1106        )
1107
1108    def test_input_mutation_return(self):
1109        def f(a, b):
1110            return torch.sin(a, out=b)
1111
1112        inp = [torch.randn(3, 3), torch.ones(3, 3)]
1113
1114        fw_graph = self.verify_aot_autograd(
1115            f, inp, test_mutation=True, keep_inp_mutations=True
1116        )
1117        self.assertExpectedInline(
1118            fw_graph.code.strip(),
1119            """\
1120def forward(self, arg0_1, arg1_1):
1121    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
1122    copy_ = torch.ops.aten.copy_.default(arg1_1, sin);  arg1_1 = sin = None
1123    return (copy_,)""",
1124        )
1125
1126    def test_input_mutation_metadata(self):
1127        def f(a, b):
1128            a.transpose_(1, 0)
1129            return a + b
1130
1131        def create_inp(req_grad):
1132            return [
1133                torch.ones(3, 3, requires_grad=req_grad),
1134                torch.ones(3, 3, requires_grad=req_grad),
1135            ]
1136
1137        self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
1138        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
1139
1140    def test_input_mutation_storage_resize_up(self):
1141        def f(a):
1142            torch.ops.inductor.resize_storage_bytes_(a, 32)
1143            # float32, 4 bytes per element, 32 bytes == 8 elements
1144            with torch.no_grad():
1145                a.copy_(torch.ones(8))
1146            return a + 1
1147
1148        inp = torch.zeros(8, requires_grad=True)
1149        # Input starts with zero-size-storage
1150        inp.untyped_storage().resize_(0)
1151
1152        fw_graph_cell = [None]
1153        compiled_f = aot_function(
1154            f,
1155            fw_compiler=make_boxed_compiler(
1156                partial(extract_graph, graph_cell=fw_graph_cell)
1157            ),
1158            bw_compiler=nop,
1159            decompositions={},
1160            keep_inference_input_mutations=True,
1161            dynamic=False,
1162        )
1163        out = compiled_f(inp)
1164        # Final functionalized graph has two mutation ops:
1165        # (1) a resize_() to resize input tensor up
1166        # (2) a copy_() to fill in the resized input with valid data
1167        self.assertExpectedInline(
1168            fw_graph_cell[0].code.strip(),
1169            """\
1170def forward(self, primals_1):
1171    resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32);  resize_storage_bytes_ = None
1172    ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
1173    copy = torch.ops.aten.copy.default(primals_1, ones);  ones = None
1174    add = torch.ops.aten.add.Tensor(copy, 1)
1175    copy_ = torch.ops.aten.copy_.default(primals_1, copy);  primals_1 = copy = copy_ = None
1176    return (add,)""",
1177        )
1178
1179    def test_input_mutation_storage_resize_down(self):
1180        def f(a):
1181            out = a.sin()
1182            torch.ops.inductor.resize_storage_bytes_(a, 0)
1183            return out
1184
1185        inp = torch.zeros(8, requires_grad=True)
1186
1187        fw_graph_cell = [None]
1188        compiled_f = aot_function(
1189            f,
1190            fw_compiler=make_boxed_compiler(
1191                partial(extract_graph, graph_cell=fw_graph_cell)
1192            ),
1193            bw_compiler=nop,
1194            decompositions={},
1195            keep_inference_input_mutations=True,
1196            dynamic=False,
1197        )
1198        out = compiled_f(inp)
1199        # Final functionalized graph has one mutation ops:
1200        # (1) a resize_() to resize input tensor down
1201        # Even though there was technically a "data mutation" on the input (from a.copy_()),
1202        # We don't include it in the graph since the final input size has zero storage
1203        self.assertExpectedInline(
1204            fw_graph_cell[0].code.strip(),
1205            """\
1206def forward(self, primals_1):
1207    sin = torch.ops.aten.sin.default(primals_1)
1208    resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0);  resize_storage_bytes_ = None
1209    return (sin, primals_1)""",
1210        )
1211
1212    #     def test_input_mutation_storage_resize_up_down(self):
1213    #         def f(a):
1214    #             torch.ops.inductor.resize_storage_bytes_(a, 32)
1215    #             # float32, 4 bytes per element, 32 bytes == 8 elements
1216    #             with torch.no_grad():
1217    #                 a.copy_(torch.ones(8))
1218    #             out = a.sin()
1219    #             torch.ops.inductor.resize_storage_bytes_(a, 0)
1220    #             return out
1221
1222    #         inp = torch.zeros(8, requires_grad=True)
1223    #         # Input starts with zero-size-storage
1224    #         inp.untyped_storage().resize_(0)
1225
1226    #         fw_graph_cell = [None]
1227    #         compiled_f = aot_function(
1228    #             f,
1229    #             fw_compiler=make_boxed_compiler(
1230    #                 partial(extract_graph, graph_cell=fw_graph_cell)
1231    #             ),
1232    #             bw_compiler=nop,
1233    #             decompositions={},
1234    #             keep_inference_input_mutations=True,
1235    #             dynamic=False,
1236    #         )
1237    #         out = compiled_f(inp)
1238    #         # Final graph has two interesting properties:
1239    #         # (1) no resizes in the functional graph, since the two resizes cancel out
1240    #         #     and the final size is zero
1241    #         # (2) no copy_ in the functional graph, even though we copied data into the input,
1242    #         #     because the input has no storage at the end of graph execution (so no data to copy)
1243    #         self.assertExpectedInline(
1244    #             fw_graph_cell[0].code.strip(),
1245    #             """\
1246    # def forward(self, primals_1):
1247    #     ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
1248    #     copy = torch.ops.aten.copy.default(primals_1, ones);  primals_1 = ones = None
1249    #     sin = torch.ops.aten.sin.default(copy)
1250    #     return [sin, copy]""",
1251    #         )
1252
1253    def test_input_mutation_storage_resize_down_and_set_(self):
1254        # Meant to mimic ppFSDP
1255        class TracableCreateParameter(torch.autograd.Function):
1256            @staticmethod
1257            def forward(ctx, tensor, placeholder):
1258                assert not tensor.requires_grad
1259                return placeholder.set_(tensor)
1260
1261            @staticmethod
1262            def backward(ctx, grad):
1263                return None, grad  # grad flows to placeholder
1264
1265        def f(dummy_param, param_shard):
1266            # simulate allgather
1267            with torch.no_grad():
1268                allgather_param = torch.cat([param_shard, param_shard])
1269            # simulate propagating grad state through dummy param, using data of allgather param
1270            dummy_param_with_grad_state = TracableCreateParameter.apply(
1271                allgather_param, dummy_param
1272            )
1273            out = dummy_param.sin()
1274            # Resize out dummy param, which now has the allgather data
1275            torch.ops.inductor.resize_storage_bytes_(dummy_param, 0)
1276            return out
1277
1278        # Simulates the local shard of our param
1279        param_shard = torch.zeros(8, requires_grad=True)
1280        # The dummy, zero-sized allgathered param that autograd will actually compute gradients on
1281        dummy_param = torch.zeros(16, requires_grad=True)
1282        dummy_param.untyped_storage().resize_(0)
1283
1284        fw_graph_cell = [None]
1285        compiled_f = aot_function(
1286            f,
1287            fw_compiler=make_boxed_compiler(
1288                partial(extract_graph, graph_cell=fw_graph_cell)
1289            ),
1290            bw_compiler=nop,
1291            decompositions={},
1292            keep_inference_input_mutations=True,
1293            dynamic=False,
1294        )
1295        out = compiled_f(dummy_param, param_shard)
1296        # Important stuff to point out:
1297        # (1) We save cat for backward (input to the sin()).
1298        #     While the original code was dummy_param.sin(),
1299        #     dummy_param actually contains the `cat` tensor due to the set_() call
1300        # (2) We emit a cat.resize_storage_(0) in the graph.
1301        #     After the set_(), cat is the actually data of dummy_param, which is what we call resize_() on
1302        self.assertExpectedInline(
1303            fw_graph_cell[0].code.strip(),
1304            """\
1305def forward(self, primals_1, primals_2):
1306    cat = torch.ops.aten.cat.default([primals_2, primals_2]);  primals_2 = None
1307    sin = torch.ops.aten.sin.default(cat)
1308    resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0);  resize_storage_bytes_ = None
1309    set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat);  primals_1 = set_ = None
1310    return (sin, cat)""",
1311        )
1312
1313    def test_input_mutation_storage_resize_before_set_(self):
1314        def f(a):
1315            with torch.no_grad():
1316                torch.ops.inductor.resize_storage_bytes_(a, 0)
1317                a.set_(torch.ones(2))
1318
1319        inp = torch.zeros(8, requires_grad=True)
1320
1321        compiled_f = aot_function(
1322            f,
1323            fw_compiler=nop,
1324            bw_compiler=nop,
1325            decompositions={},
1326            keep_inference_input_mutations=True,
1327            dynamic=False,
1328        )
1329        out = compiled_f(inp)
1330
1331    # def test_input_mutation_storage_resize_not_supported(self):
1332    #     def f(a):
1333    #         a.mul_(2)
1334    #         torch.ops.inductor.resize_storage_bytes_(a, 0)
1335    #         return a
1336
1337    #     inp = torch.zeros(8, requires_grad=True)
1338
1339    #     with self.assertRaisesRegex(
1340    #         AssertionError, "the input has other mutations that we cannot"
1341    #     ):
1342    #         compiled_f = aot_function(
1343    #             f,
1344    #             fw_compiler=nop,
1345    #             bw_compiler=nop,
1346    #             decompositions={},
1347    #             keep_inference_input_mutations=True,
1348    #             dynamic=False,
1349    #         )
1350    #         out = compiled_f(inp)
1351
1352    def test_input_output_aliase_custom_autograd_function(self):
1353        class Foo(torch.autograd.Function):
1354            @staticmethod
1355            def forward(ctx, x):
1356                return x
1357
1358            @staticmethod
1359            def backward(ctx, gx):
1360                return gx * 0.5
1361
1362        def f(x):
1363            return Foo.apply(x)
1364
1365        inp = [torch.ones(2, 2, requires_grad=True)]
1366        self.verify_aot_autograd(f, inp, test_mutation=False)
1367
1368    def test_input_mutation_requires_grad_detach(self):
1369        # Here, "a" requires grad, and gets mutated, so we append a copy_() to the end of the graph.
1370        # Its mutation doesn't take part in autograd though, because we mutated a detach'd view.
1371        # Need to make sure that this copy_() doesn't error, and doesn't participate in autograd either.
1372        def f(a):
1373            a.detach().mul_(2)
1374            return a + 3
1375
1376        inp = [torch.ones(4, requires_grad=True)]
1377        self.verify_aot_autograd(f, inp, test_mutation=False)
1378        inp = [torch.ones(4, requires_grad=True)]
1379        # test_mutation=True will first do some compute on inp, so it is no longer an autograd leaf
1380        # by the time it becomes a graph input. Good to test both cases.
1381        self.verify_aot_autograd(f, inp, test_mutation=True)
1382
1383    def test_input_mutation_hidden_from_autograd_aliasing(self):
1384        def f(a):
1385            a_alias = a.view(-1)
1386            with torch.no_grad():
1387                a_alias.mul_(2)
1388            return a + 1
1389
1390        inp = [torch.ones(4, requires_grad=True)]
1391        # The important bit: we detected that the input mutation is safe
1392        # to include **inside** the graph, since it was under no_grad
1393        # (so all we need to do is use mark_dirty() on the input to bump the VC)
1394        fw_graph = self.verify_aot_autograd(
1395            f, inp, test_mutation=True, keep_inp_mutations=True
1396        )
1397        self.assertExpectedInline(
1398            fw_graph.code.strip(),
1399            """\
1400def forward(self, primals_1):
1401    view = torch.ops.aten.view.default(primals_1, [-1])
1402    mul = torch.ops.aten.mul.Tensor(view, 2);  view = None
1403    view_1 = torch.ops.aten.view.default(mul, [4]);  mul = None
1404    add = torch.ops.aten.add.Tensor(view_1, 1)
1405    copy_ = torch.ops.aten.copy_.default(primals_1, view_1);  primals_1 = view_1 = copy_ = None
1406    return (add,)""",
1407        )
1408
1409    def test_input_mutation_requires_grad_no_grad(self):
1410        def f(a):
1411            with torch.no_grad():
1412                a.mul_(2)
1413            return a + 3
1414
1415        inp = [torch.ones(4, requires_grad=True)]
1416        fw_graph = self.verify_aot_autograd(
1417            f, inp, test_mutation=True, keep_inp_mutations=True
1418        )
1419        # Even though the input requires_grad, we expect the keep the input mutation in the graph
1420        # (Even though this is a training graph!)
1421        self.assertExpectedInline(
1422            fw_graph.code.strip(),
1423            """\
1424def forward(self, primals_1):
1425    mul = torch.ops.aten.mul.Tensor(primals_1, 2)
1426    add = torch.ops.aten.add.Tensor(mul, 3)
1427    copy_ = torch.ops.aten.copy_.default(primals_1, mul);  primals_1 = mul = copy_ = None
1428    return (add,)""",
1429        )
1430
1431    def test_input_mutation_requires_grad_no_grad_inference_graph(self):
1432        def f(a):
1433            with torch.no_grad():
1434                a.mul_(2)
1435                return a + 3
1436
1437        inp = [torch.ones(4, requires_grad=True)]
1438        # Even though the input requires_grad, we expect the keep the input mutation in the graph
1439        fw_graph = self.verify_aot_autograd(
1440            f, inp, test_mutation=True, keep_inp_mutations=True
1441        )
1442
1443        self.assertExpectedInline(
1444            fw_graph.code.strip(),
1445            """\
1446def forward(self, arg0_1):
1447    mul = torch.ops.aten.mul.Tensor(arg0_1, 2)
1448    add = torch.ops.aten.add.Tensor(mul, 3)
1449    copy_ = torch.ops.aten.copy_.default(arg0_1, mul);  arg0_1 = mul = copy_ = None
1450    return (add,)""",
1451        )
1452
1453    def test_input_mutation_requires_grad_no_grad_detach_mixed(self):
1454        # Perform a mix of mutations on a:
1455        # 1 normal, 1 in no_grad, 1 on a detach'd tensor.
1456        # Only the first should participate in gradient computation.
1457        def f(a):
1458            a.detach().mul_(2)
1459            a.mul_(3)
1460            with torch.no_grad():
1461                a.mul_(4)
1462            return a + 5
1463
1464        inp = [torch.ones(4, requires_grad=True)]
1465        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1466
1467    def test_input_mutation_metadata2(self):
1468        def f(a):
1469            a.transpose_(1, 0)
1470            a.mul_(2)
1471            return a + 1
1472
1473        inp = [torch.ones(3, 3, requires_grad=True)]
1474        self.verify_aot_autograd(f, inp, test_mutation=True)
1475        inp = [torch.ones(3, 3, requires_grad=False)]
1476        self.verify_aot_autograd(f, inp, test_mutation=True)
1477
1478    def test_input_mutation_batchnorm(self):
1479        def f(inpt, weight, bias, running_mean, running_var):
1480            # This is additionally a good test, because the input tensors that we mutate
1481            # are *also* saved for backwards.
1482            # This tests that what we save for the backward is actually cloned inputs,
1483            # and not the original inputs that got mutated.
1484            return torch._native_batch_norm_legit(
1485                inpt, weight, bias, running_mean, running_var, True, 0.5, 1e-5
1486            )
1487
1488        def create_inp(req_grad):
1489            return [
1490                torch.ones(2, 5, 5, 5, requires_grad=req_grad),
1491                torch.ones(5, requires_grad=req_grad),
1492                torch.ones(5, requires_grad=req_grad),
1493                torch.ones(5),
1494                torch.ones(5),
1495            ]
1496
1497        from torch._decomp import get_decompositions
1498
1499        # This simulates what inductor does (running the fw + bw decompositions)
1500        decompositions = get_decompositions(
1501            [
1502                torch.ops.aten._native_batch_norm_legit_functional,
1503                torch.ops.aten.native_batch_norm_backward,
1504            ]
1505        )
1506        self.verify_aot_autograd(
1507            f, create_inp(True), test_mutation=True, decompositions=decompositions
1508        )
1509        self.verify_aot_autograd(
1510            f, create_inp(False), test_mutation=True, decompositions=decompositions
1511        )
1512
1513    def test_batchnorm_inference(self):
1514        inp = [
1515            torch.ones(2, 5, 5, 5, requires_grad=True),
1516            torch.ones(5, requires_grad=True),
1517            torch.ones(5, requires_grad=True),
1518            torch.ones(5),
1519            torch.ones(5),
1520        ]
1521
1522        m = torch.nn.BatchNorm2d(4, 4)
1523        m.eval()
1524        fw_graph_cell = [None]
1525        inp = torch.ones(4, 4, 4, 4)
1526        fw_graph_cell = [None]
1527        compiled_m = aot_module(
1528            m,
1529            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
1530            bw_compiler=nop,
1531            keep_inference_input_mutations=True,
1532        )
1533        inp = torch.ones(4, 4, 4, 4)
1534        with torch.no_grad():
1535            out = compiled_m(inp)
1536        # expectation: there are no copy_() calls in the decomposed batch norm when running under training=False (eval mode)
1537        code = fw_graph_cell[0].code.strip()
1538        self.assertTrue("copy_" not in str(code))
1539
1540    def test_input_output_view_simple(self):
1541        def f(a):
1542            return a.view(-1)
1543
1544        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
1545        self.verify_aot_autograd(f, inp, test_mutation=True)
1546        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
1547        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1548        # Outputs that alias inputs are pulled out of the graph entirely, so we don't compile anything here
1549        self.assertExpectedInline(
1550            fw_graph.code.strip(),
1551            """\
1552def forward(self, arg0_1):
1553    view = torch.ops.aten.view.default(arg0_1, [-1]);  arg0_1 = None
1554    return (view,)""",
1555        )
1556
1557    def test_input_output_view_mutate_multiple(self):
1558        def f(a, b, c):
1559            a.mul_(2)
1560            c.mul_(3)
1561            return b.view(2, 2), c.view(2, 2)
1562
1563        def create_inp(req_grad):
1564            return [
1565                torch.ones(2, 2, requires_grad=req_grad).add(1),
1566                torch.ones(2, 2, requires_grad=req_grad).add(1),
1567                torch.ones(2, 2, requires_grad=req_grad).add(1),
1568            ]
1569
1570        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
1571        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
1572        # The original function returned two outputs, both of which aliased inputs.
1573        # We expect two outputs in the functional graph, a_updated and c_updated.
1574        # The actual aliased outputs themselves aren't in the compiled forward graph;
1575        # Instead, they're generated outside of  the graph.
1576        self.assertExpectedInline(
1577            fw_graph.code.strip(),
1578            """\
1579def forward(self, primals_1, primals_2, primals_3):
1580    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
1581    clone_1 = torch.ops.aten.clone.default(primals_3);  primals_3 = None
1582    mul = torch.ops.aten.mul.Tensor(clone, 2);  clone = None
1583    mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3);  clone_1 = None
1584    view = torch.ops.aten.view.default(primals_2, [2, 2]);  primals_2 = None
1585    view_2 = torch.ops.aten.view.default(mul_1, [2, 2])
1586    return (mul, mul_1, view, view_2)""",
1587        )
1588
1589    def test_input_output_view_metadata_mutate_multiple(self):
1590        def f(a, b, c):
1591            b.mul_(3)
1592            c.t_()
1593            return a.view(2, 2), b.view(2, 2), c.view(2, 2)
1594
1595        def create_inp(req_grad):
1596            return [
1597                torch.ones(2, 2, requires_grad=req_grad).add(1),
1598                torch.ones(2, 2, requires_grad=req_grad).add(1),
1599                torch.ones(2, 2, requires_grad=req_grad).add(1),
1600            ]
1601
1602        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
1603        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
1604        # Important thing to check here: of the three inputs:
1605        # Only the b.mul_(3) should show up in the graph (we functionalize it and return it).
1606        # Everything else that does not show up in the graph includes:
1607        # - The metadata mutation on c (we do it outside the graph)
1608        # - All 3 original fw outputs, which are aliases of inputs (we regenerate them outside of the graph)
1609        self.assertExpectedInline(
1610            fw_graph.code.strip(),
1611            """\
1612def forward(self, primals_1, primals_2, primals_3):
1613    clone = torch.ops.aten.clone.default(primals_2);  primals_2 = None
1614    view = torch.ops.aten.view.default(primals_3, [2, 2]);  primals_3 = None
1615    mul = torch.ops.aten.mul.Tensor(clone, 3);  clone = None
1616    t = torch.ops.aten.t.default(view);  view = None
1617    view_1 = torch.ops.aten.view.default(primals_1, [2, 2]);  primals_1 = None
1618    view_3 = torch.ops.aten.view.default(t, [2, 2])
1619    view_4 = torch.ops.aten.view.default(mul, [2, 2])
1620    return (mul, t, view_1, view_4, view_3)""",
1621        )
1622
1623    def test_input_mutation_and_output_view(self):
1624        def f(a):
1625            a.add_(1)
1626            return a.view(-1)
1627
1628        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
1629        self.verify_aot_autograd(f, inp, test_mutation=True)
1630        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
1631        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1632        # Here, total # of outputs is 1 because:
1633        # - num_mutated_inps = 1 (a_updated)
1634        # - num_fw_outputs = 0 (the output is an alias of the input, so we move it outside the compiled fw)
1635        self.assertExpectedInline(
1636            fw_graph.code.strip(),
1637            """\
1638def forward(self, primals_1):
1639    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
1640    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
1641    view_1 = torch.ops.aten.view.default(add, [-1])
1642    return (add, view_1)""",
1643        )
1644
1645    def test_input_mutation_output_view_multiple(self):
1646        def f(a, b, c, d):
1647            b.transpose_(1, 0)
1648            c.add_(1)
1649            return d + 1, b.diagonal(), a + c
1650
1651        def create_inp(req_grad):
1652            return [
1653                torch.arange(4, requires_grad=req_grad, dtype=torch.float32)
1654                .view(2, 2)
1655                .add(1),
1656                torch.arange(4, requires_grad=req_grad, dtype=torch.float32)
1657                .view(2, 2)
1658                .add(1),
1659                torch.ones(2, 2, requires_grad=req_grad).add(1),
1660                torch.ones(2, 2, requires_grad=req_grad).add(1),
1661            ]
1662
1663        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
1664        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
1665        self.assertExpectedInline(
1666            fw_graph.code.strip(),
1667            """\
1668def forward(self, primals_1, primals_2, primals_3, primals_4):
1669    view = torch.ops.aten.view.default(primals_2, [2, 2]);  primals_2 = None
1670    clone = torch.ops.aten.clone.default(primals_3);  primals_3 = None
1671    transpose = torch.ops.aten.transpose.int(view, 1, 0);  view = None
1672    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
1673    add_1 = torch.ops.aten.add.Tensor(primals_4, 1);  primals_4 = None
1674    diagonal = torch.ops.aten.diagonal.default(transpose)
1675    add_2 = torch.ops.aten.add.Tensor(primals_1, add);  primals_1 = None
1676    return (transpose, add, add_1, diagonal, add_2)""",
1677        )
1678
1679    def test_output_aliases_intermediate_single(self):
1680        def f(a):
1681            out = torch.mul(a, 3)
1682            return out.view(-1)
1683
1684        inp = [torch.ones(3, 3, requires_grad=False)]
1685        self.verify_aot_autograd(f, inp, test_mutation=True)
1686        inp = [torch.ones(3, 3, requires_grad=True)]
1687        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1688        # In AOTAutograd, we are obligated to make the compiled forward directly return `out`,
1689        # and reconstruct `out.view(-1)` as a fresh output.
1690        self.assertExpectedInline(
1691            fw_graph.code.strip(),
1692            """\
1693def forward(self, primals_1):
1694    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
1695    view = torch.ops.aten.view.default(mul, [-1]);  mul = None
1696    return (view,)""",
1697        )
1698
1699    def test_output_aliases_input_multi_output_view_should_raise_autograd_error(self):
1700        def f1(a):
1701            return list(a.unbind(0))
1702
1703        f1_compiled = aot_function(f1, nop)
1704
1705        inp1 = torch.ones(3, 3, requires_grad=True).clone()
1706        inp2 = torch.ones(3, 3, requires_grad=True).clone()
1707        inp3 = torch.ones(3, 3, requires_grad=True).clone()
1708
1709        with self.assertRaisesRegex(
1710            RuntimeError, "Such functions do not allow the output views"
1711        ):
1712            out_test1 = f1_compiled(inp1)
1713            # This raises a runtime error from autograd in eager mode
1714            out_test1[0].mul_(2)
1715
1716        with self.assertRaisesRegex(
1717            RuntimeError, "Such functions do not allow the output views"
1718        ):
1719            out_test2 = f1_compiled(inp2)
1720            inp2.mul_(2)
1721            # In eager mode, if we mutate a tensor, any multi-output-view aliases
1722            # get their grad_fn replaced with error nodes, so accessing grad_fn should error
1723            grad_fn = out_test2[0].grad_fn
1724
1725        with self.assertRaisesRegex(
1726            RuntimeError, "Such functions do not allow the output views"
1727        ):
1728            out_test3 = f1_compiled(inp3)
1729            out_test1[0].detach().mul_(2)
1730            # The above case also applies to detached aliases (they turn the multi-output-view
1731            # alias's grad_fns into error nodes)
1732            grad_fn = out_test2[0].grad_fn
1733
1734    def test_output_aliases_input_multi_output_view(self):
1735        # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
1736        def f1(a):
1737            return list(a.unbind(0))
1738
1739        inp = torch.ones(3, 3, requires_grad=True)
1740        inp_ref = torch.ones(3, 3, requires_grad=True)
1741        f1_compiled = aot_function(f1, nop)
1742
1743        out_ref = f1(inp_ref)
1744        out_test = f1_compiled(inp)
1745        # Assert that we get CompiledFunctionBackward in the backward graph,
1746        # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case.
1747        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
1748        self.assertTrue(
1749            all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test)
1750        )
1751
1752        sum(out_ref).sum().backward()
1753        sum(out_test).sum().backward()
1754        self.assertEqual(inp_ref.grad, inp.grad)
1755
1756        # Several of the outputs are from multi-output views.
1757        # However: they are part of the same alias set as "a", and "a.view(out.shape)",
1758        # which are both user-visible.
1759        # AOTAutograd will not try to be smart here and hide the aliasing relationships from autograd.
1760        # Instead, it will perform its "output aliases input" logic, and regenerate all aliases.
1761        def f3(a):
1762            return *list(a.unbind(0)), a.view(a.shape)
1763
1764        inp = torch.ones(3, 3, requires_grad=True)
1765        inp_ref = torch.ones(3, 3, requires_grad=True)
1766        f3_compiled = aot_function(f3, nop)
1767
1768        inp_ref_clone = inp_ref.clone()
1769        inp_clone = inp.clone()
1770        out_ref = f3(inp_ref_clone)
1771        out_test = f3_compiled(inp_clone)
1772        self.assertTrue(all("UnbindBackward" in str(o.grad_fn) for o in out_test[:3]))
1773
1774        # The last output is not from a multi-output view, so autograd will let us mutate it.
1775        out_ref[-1].mul_(2)
1776        out_test[-1].mul_(2)
1777        # Also mutate the input, which should affect the aliased output.
1778        inp_ref_clone.view(-1).mul_(3)
1779        inp_clone.view(-1).mul_(3)
1780        # Do backward
1781        (inp_ref + out_ref[-1]).sum().backward()
1782        (inp + out_test[-1]).sum().backward()
1783        self.assertEqual(inp_ref.grad, inp.grad)
1784
1785    def test_output_aliases_intermediate_multi_output_view(self):
1786        # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
1787        def f1(a):
1788            out = torch.mul(a, 3)
1789            return list(out.unbind(0))
1790
1791        inp = torch.ones(3, 3, requires_grad=True)
1792        inp_ref = torch.ones(3, 3, requires_grad=True)
1793        f1_compiled = aot_function(f1, nop)
1794
1795        out_ref = f1(inp_ref)
1796        out_test = f1_compiled(inp)
1797        # Assert that we get CompiledFunctionBackward in the backward graph,
1798        # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case.
1799        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
1800        self.assertTrue(
1801            all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test)
1802        )
1803
1804        sum(out_ref).sum().backward()
1805        sum(out_test).sum().backward()
1806        self.assertEqual(inp_ref.grad, inp.grad)
1807
1808        # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
1809        def f2(a):
1810            out = torch.mul(a, 3)
1811            return *list(out.unbind(0)), out
1812
1813        inp = torch.ones(3, 3, requires_grad=True)
1814        inp_ref = torch.ones(3, 3, requires_grad=True)
1815        f2_compiled = aot_function(f2, nop)
1816
1817        out_ref = f2(inp_ref)
1818        out_test = f2_compiled(inp)
1819        # Assert that we get CompiledFunctionBackward in the backward graph,
1820        # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case.
1821        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
1822        self.assertTrue(
1823            all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test)
1824        )
1825
1826        # The last output is not from a multi-output view, so autograd will let us mutate it.
1827        out_ref[-1].mul_(2)
1828        out_test[-1].mul_(2)
1829        out_ref[-1].sum().backward()
1830        out_test[-1].sum().backward()
1831        self.assertEqual(inp_ref.grad, inp.grad)
1832
1833        # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
1834        def f3(a):
1835            out = torch.mul(a, 3)
1836            return *list(out.unbind(0)), out.view(out.shape)
1837
1838        inp = torch.ones(3, 3, requires_grad=True)
1839        inp_ref = torch.ones(3, 3, requires_grad=True)
1840        f3_compiled = aot_function(f3, nop)
1841
1842        out_ref = f3(inp_ref)
1843        out_test = f3_compiled(inp)
1844        # Assert that we get CompiledFunctionBackward in the backward graph,
1845        # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case.
1846        # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
1847        self.assertTrue(
1848            all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test)
1849        )
1850
1851        # The last output is not from a multi-output view, so autograd will let us mutate it.
1852        out_ref[-1].mul_(2)
1853        out_test[-1].mul_(2)
1854        out_ref[-1].sum().backward()
1855        out_test[-1].sum().backward()
1856        self.assertEqual(inp_ref.grad, inp.grad)
1857
1858        # There are 5 outputs that all alias each other.
1859        # 3 of them come from multi-output views, but the other 3 are "ordinary" aliases.
1860        # Therefore, AOTAutograd will not attempt the multi-output-view optimization,
1861        # and apply the intermediate_base logic to all aliases.
1862        # (In theory we could probably get AOTAutograd to only apply the intermediate base
1863        # logic to the last 2 outputs and not the first 3. We should probably
1864        # just do the graph partitioning defined in this doc instead though).
1865        # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
1866        def f4(a):
1867            out = torch.mul(a, 3)
1868            # also return the graph intermediate directly,
1869            # which will force AOTAutograd to do the "intermediate base" logic.
1870            # (Why? The user can mutate "out", which should change the autograd metadata
1871            #  of the other aliased outputs)
1872            return *list(out.unbind(0)), out, out.view(out.shape)
1873
1874        inp = torch.ones(3, 3, requires_grad=True)
1875        inp_ref = torch.ones(3, 3, requires_grad=True)
1876        f4_compiled = aot_function(f4, nop)
1877
1878        out_ref = f4(inp_ref)
1879        out_test = f4_compiled(inp)
1880        # Mutate the last output of f4 (autograd will allow this, since it is not a multi-output view,
1881        # as long as *only* the non-multi-output views participate in the backward)
1882        # Note: We could probably try to hide **only** the multi-output views from autograd here
1883        # and only do the intermediate base logic for the last two aliases.
1884        # Longer term solution of graph partitioning is probably cleaner though (see the note).
1885        out_ref[-1].mul_(2)
1886        out_test[-1].mul_(2)
1887
1888        out_ref_sum = out_ref[-1] + out_ref[-2]
1889        out_test_sum = out_test[-1] + out_test[-2]
1890        out_ref_sum.sum().backward()
1891        out_test_sum.sum().backward()
1892        self.assertEqual(inp_ref.grad, inp.grad)
1893
1894    def test_output_aliases_intermediate_mutation_linear(self):
1895        def f(x):
1896            return (x + 1).view(-1)
1897
1898        inp = [torch.ones(3, 3, requires_grad=True)]
1899        # use inductor's decomps (which will e.g. turn _unsafe_view() into view())
1900        from torch._inductor.decomposition import decompositions
1901
1902        f_compiled = aot_function(f, nop, decompositions=decompositions)
1903
1904        out_ref = f(*inp)
1905        out_test = f_compiled(*inp)
1906
1907        out_ref.mul_(2)
1908        out_test.mul_(2)
1909        self.assertEqual(out_ref, out_test)
1910
1911    def test_output_aliases_intermediate_no_grad(self):
1912        def f(a, b):
1913            out = torch.mul(a, 3)
1914            # First output is an alias of an intermediate that doesn't require grad
1915            return out.view(-1), b.add(1)
1916
1917        inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=False)]
1918        self.verify_aot_autograd(f, inp, test_mutation=True)
1919        inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=True)]
1920        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1921        # important bit: we don't bother generating an intermediate base as an output in the graph,
1922        # because the intermediate base itself didn't require gradients.
1923        # (the only problematic case is when both the base and the aliasesed output require gradients).
1924        self.assertExpectedInline(
1925            fw_graph.code.strip(),
1926            """\
1927def forward(self, primals_1, primals_2):
1928    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
1929    view = torch.ops.aten.view.default(mul, [-1]);  mul = None
1930    add = torch.ops.aten.add.Tensor(primals_2, 1);  primals_2 = None
1931    return (view, add)""",
1932        )
1933
1934    def test_output_aliases_intermediate_returned_multiple_times(self):
1935        def f(a):
1936            out = torch.mul(a, 3)
1937            out_view = out.view(-1)
1938            return out, out_view, out
1939
1940        inp = [torch.ones(3, 3, requires_grad=False)]
1941        self.verify_aot_autograd(f, inp, test_mutation=True)
1942        inp = [torch.ones(3, 3, requires_grad=True)]
1943        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1944
1945    def test_output_aliases_intermediate_multiple(self):
1946        def f(a):
1947            out = torch.mul(a, 3)
1948            # AOTAutograd should manually generate these two output views in the epilogue.
1949            return out.view(-1), out.view(-1)
1950
1951        inp = [torch.ones(3, 3, requires_grad=False)]
1952        self.verify_aot_autograd(f, inp, test_mutation=True)
1953        inp = [torch.ones(3, 3, requires_grad=True)]
1954        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1955        self.assertExpectedInline(
1956            fw_graph.code.strip(),
1957            """\
1958def forward(self, primals_1):
1959    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
1960    view = torch.ops.aten.view.default(mul, [-1])
1961    view_1 = torch.ops.aten.view.default(mul, [-1])
1962    return (view, view_1, mul)""",
1963        )
1964
1965    def test_output_aliases_intermediate_and_returned(self):
1966        def f(a):
1967            out = torch.mul(a, 3)
1968            # AOTAutograd should manually generate the first output (a view of an intermediate)
1969            # but not the second (which is itself the intermediate for the first)
1970            return out.view(-1), out
1971
1972        inp = [torch.ones(3, 3, requires_grad=False)]
1973        self.verify_aot_autograd(f, inp, test_mutation=True)
1974        inp = [torch.ones(3, 3, requires_grad=True)]
1975        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1976        self.assertExpectedInline(
1977            fw_graph.code.strip(),
1978            """\
1979def forward(self, primals_1):
1980    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
1981    view = torch.ops.aten.view.default(mul, [-1])
1982    return (view, mul)""",
1983        )
1984
1985    def test_output_aliases_intermediate_and_returned_flipped(self):
1986        def f(a):
1987            out = torch.mul(a, 3)
1988            # AOTAutograd should manually generate the first output (a view of an intermediate)
1989            # but not the second (which is itself the intermediate for the first)
1990            return out, out.view(-1)
1991
1992        inp = [torch.ones(3, 3, requires_grad=False)]
1993        self.verify_aot_autograd(f, inp, test_mutation=True)
1994        inp = [torch.ones(3, 3, requires_grad=True)]
1995        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
1996        self.assertExpectedInline(
1997            fw_graph.code.strip(),
1998            """\
1999def forward(self, primals_1):
2000    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
2001    view = torch.ops.aten.view.default(mul, [-1])
2002    return (mul, view)""",
2003        )
2004
2005    def test_output_aliases_intermediate_and_returned_different_grad(self):
2006        def f(a):
2007            out = torch.mul(a, 3)
2008            # AOTAutograd should manually generate the first output (a view of an intermediate)
2009            # but not the second (which is itself the intermediate for the first)
2010            return out.view(-1), out, out[0].detach()
2011
2012        inp = [torch.ones(3, 3, requires_grad=False)]
2013        self.verify_aot_autograd(f, inp, test_mutation=True)
2014        inp = [torch.ones(3, 3, requires_grad=True)]
2015        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
2016        self.assertExpectedInline(
2017            fw_graph.code.strip(),
2018            """\
2019def forward(self, primals_1):
2020    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
2021    view = torch.ops.aten.view.default(mul, [-1])
2022    select = torch.ops.aten.select.int(mul, 0, 0)
2023    detach = torch.ops.aten.detach.default(select);  select = None
2024    detach_1 = torch.ops.aten.detach.default(detach);  detach = None
2025    detach_2 = torch.ops.aten.detach.default(detach_1);  detach_1 = None
2026    return (view, mul, detach_2)""",
2027        )
2028
2029    def test_output_aliases_intermediate_inplace_view(self):
2030        def f(a):
2031            out = torch.mul(a, 3)
2032            out.t_()
2033            return out
2034
2035        inp = [torch.ones(2, 4, requires_grad=True)]
2036
2037        # TODO: fix this test.
2038        # See https://github.com/pytorch/pytorch/issues/90507
2039        # self.verify_aot_autograd(f, inp, test_mutation=True)
2040
2041    def test_output_aliases_intermediate_inplace_view_with_detach(self):
2042        def f(a):
2043            out = torch.mul(a, 3)
2044            out.t_()
2045            out.detach_()
2046            # Thanks to the detach_() AOT Autograd doesn't need to do anything.
2047            # `out` will show up as having OutputType.non_alias,
2048            # and ._is_view() == False
2049            return out, a + 1
2050
2051        inp = [torch.ones(2, 4, requires_grad=False)]
2052        self.verify_aot_autograd(f, inp, test_mutation=True)
2053        inp = [torch.ones(2, 4, requires_grad=True)]
2054        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
2055        self.assertExpectedInline(
2056            fw_graph.code.strip(),
2057            """\
2058def forward(self, primals_1):
2059    mul = torch.ops.aten.mul.Tensor(primals_1, 3)
2060    t = torch.ops.aten.t.default(mul);  mul = None
2061    add = torch.ops.aten.add.Tensor(primals_1, 1);  primals_1 = None
2062    return (t, add)""",
2063        )
2064
2065    def test_output_aliases_intermediate_inplace_view_and_view(self):
2066        def f(a):
2067            out = torch.mul(a, 3)
2068            out_view = out.unsqueeze(0)
2069            out.t_()
2070            out_view2 = out.unsqueeze(0)
2071            return out_view, out, out_view2
2072
2073        inp = [torch.ones(2, 4, requires_grad=True)]
2074
2075        # TODO: fix this test.
2076        # See <github issue link>
2077        # self.verify_aot_autograd(f, inp, test_mutation=True)
2078
2079    def test_output_aliases_intermediate_multiple_mixed(self):
2080        def f(a):
2081            out1 = torch.mul(a, 3)
2082            out2 = torch.mul(a, 4)
2083            # AOTAutograd should manually generate these two output views in the epilogue.
2084            return out1.view(-1), out2.transpose(1, 0), out1.transpose(1, 0)
2085
2086        inp = [torch.ones(3, 3, requires_grad=False)]
2087        self.verify_aot_autograd(f, inp, test_mutation=True)
2088        inp = [torch.ones(3, 3, requires_grad=True)]
2089        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
2090        self.assertExpectedInline(
2091            fw_graph.code.strip(),
2092            """\
2093def forward(self, primals_1):
2094    mul = torch.ops.aten.mul.Tensor(primals_1, 3)
2095    mul_1 = torch.ops.aten.mul.Tensor(primals_1, 4);  primals_1 = None
2096    view = torch.ops.aten.view.default(mul, [-1])
2097    transpose = torch.ops.aten.transpose.int(mul_1, 1, 0);  mul_1 = None
2098    transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
2099    return (view, transpose, transpose_1, mul)""",
2100        )
2101
2102    def test_output_all_alias_types(self):
2103        # There are 3 types of aliasing that require us to return metadata in the compiled fw:
2104        # (1) outputs that are views of inputs
2105        # (2) outputs that are views of intermediates
2106        # (3) inputs that get metadata mutations
2107        # test all 3 of them here
2108        def f(a):
2109            a.transpose_(1, 0)
2110            tmp = a.mul(2)
2111            return tmp.squeeze(), tmp.transpose(1, 0), a.unsqueeze(0)
2112
2113        def inp_callable(req_grad):
2114            x = torch.ones(1, 2, 4, requires_grad=req_grad).clone()
2115            return [(x,), (x,)]
2116
2117        self.verify_aot_autograd(
2118            f, partial(inp_callable, req_grad=False), test_mutation=True
2119        )
2120        fw_graph = self.verify_aot_autograd(
2121            f, partial(inp_callable, req_grad=True), test_mutation=True
2122        )
2123        # TODO: make this test run with dynamic shapes so it is more meaningful
2124        # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta)
2125        self.assertExpectedInline(
2126            fw_graph.code.strip(),
2127            """\
2128def forward(self, primals_1):
2129    view = torch.ops.aten.view.default(primals_1, [1, 2, 4]);  primals_1 = None
2130    transpose = torch.ops.aten.transpose.int(view, 1, 0);  view = None
2131    mul = torch.ops.aten.mul.Tensor(transpose, 2)
2132    squeeze = torch.ops.aten.squeeze.default(mul)
2133    transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
2134    unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0)
2135    return (transpose, squeeze, transpose_1, unsqueeze, mul)""",
2136        )
2137
2138    @parametrize("req_grad", [False, True])
2139    def test_subclass_metadata_mutation(self, req_grad):
2140        def f(a):
2141            a.transpose_(1, 0)
2142            tmp = a.mul(2)
2143            return tmp.transpose(1, 0)
2144
2145        def inp_callable(req_grad):
2146            x = torch.ones(1, 2, 4, requires_grad=req_grad).clone()
2147            return [(x,), (x,)]
2148
2149        # See https://github.com/pytorch/pytorch/issues/114975
2150        with self.assertRaisesRegex(
2151            RuntimeError,
2152            "Metadata mutations are currently not allowed on tensor subclasses",
2153        ):
2154            self.verify_aot_autograd(
2155                f,
2156                partial(inp_callable, req_grad=req_grad),
2157                test_mutation=True,
2158                make_inputs_subclasses=True,
2159            )
2160
2161    def test_input_data_and_metadata_mutation(self):
2162        def f(a):
2163            a.t_()
2164            a[0].mul_(2)
2165            return a.view(a.shape)
2166
2167        inp = [torch.ones(3, 3, requires_grad=False)]
2168        self.verify_aot_autograd(f, inp, test_mutation=True)
2169        inp = [torch.ones(3, 3, requires_grad=True)]
2170        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
2171        self.assertExpectedInline(
2172            fw_graph.code.strip(),
2173            """\
2174def forward(self, primals_1):
2175    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
2176    t = torch.ops.aten.t.default(clone)
2177    select = torch.ops.aten.select.int(t, 0, 0);  t = None
2178    mul = torch.ops.aten.mul.Tensor(select, 2);  select = None
2179    t_1 = torch.ops.aten.t.default(clone);  clone = None
2180    select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0);  t_1 = mul = None
2181    t_2 = torch.ops.aten.t.default(select_scatter);  select_scatter = None
2182    t_4 = torch.ops.aten.t.default(t_2)
2183    t_6 = torch.ops.aten.t.default(t_2);  t_2 = None
2184    view_1 = torch.ops.aten.view.default(t_6, [3, 3]);  t_6 = None
2185    return (t_4, view_1)""",
2186        )
2187
2188    def test_view_and_inplace_view(self):
2189        def f(a, b):
2190            a.t_()
2191            return b.view(b.shape), a.view(a.shape)
2192
2193        def create_inp(req_grad):
2194            return [
2195                torch.ones(3, 3, requires_grad=req_grad),
2196                torch.ones(3, 3, requires_grad=req_grad),
2197            ]
2198
2199        self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
2200        fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
2201        self.assertExpectedInline(
2202            fw_graph.code.strip(),
2203            """\
2204def forward(self, arg0_1, arg1_1):
2205    t = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
2206    view = torch.ops.aten.view.default(arg1_1, [3, 3]);  arg1_1 = None
2207    view_1 = torch.ops.aten.view.default(t, [3, 3])
2208    return (t, view, view_1)""",
2209        )
2210
2211    def test_view_detach(self):
2212        def f(a):
2213            tmp = a.detach()
2214            a.mul_(2)
2215            return a, tmp
2216
2217        inp = [torch.ones(3, 3, requires_grad=True)]
2218        self.verify_aot_autograd(f, inp, test_mutation=True)
2219        inp = [torch.ones(3, 3, requires_grad=False)]
2220        self.verify_aot_autograd(f, inp, test_mutation=True)
2221
2222    def test_input_inplace_requires_grad_true(self):
2223        def f(a, b):
2224            a.requires_grad_(True)
2225            return a.mul(3), b.mul(4)
2226
2227        inp = [
2228            # First inp doesnt require grad, but we switch it on
2229            torch.ones(3, 3, requires_grad=False),
2230            torch.ones(3, 3, requires_grad=True),
2231        ]
2232
2233        fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
2234        self.assertExpectedInline(
2235            fw_graph.code.strip(),
2236            """\
2237def forward(self, primals_1, primals_2):
2238    mul = torch.ops.aten.mul.Tensor(primals_1, 3);  primals_1 = None
2239    mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4);  primals_2 = None
2240    return (mul, mul_1)""",
2241        )
2242
2243    # This is a torture test:
2244    # a and b get turned into a synthetic base in the compiled graph
2245    # One gets a data mutation, the other gets a metadata mutation.
2246    # We need to make sure that the metadata mutation gets propagated
2247    # back to the original input.
2248    @skipIfDynamoInput("Dynamo removes runtime error")
2249    def test_input_data_and_metadata_mutation_aliases_other_input(self):
2250        # a and b are aliased
2251        def f(a, b):
2252            a.mul_(2)
2253            b.t_()
2254            return a.mul(b)
2255
2256        def inp_callable(req_grad):
2257            base = torch.ones(2, 2, requires_grad=req_grad)
2258            # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them.
2259            x = base.add(1)
2260            inp1 = x[0]
2261            inp2 = x[0]
2262            return [base], [inp1, inp2]
2263
2264        self.verify_aot_autograd(
2265            f, partial(inp_callable, req_grad=False), test_mutation=True
2266        )
2267        self.verify_aot_autograd(
2268            f, partial(inp_callable, req_grad=True), test_mutation=True
2269        )
2270        with self.assertRaisesRegex(
2271            RuntimeError,
2272            "Encountered aliased inputs that are mutated in the graph, but",
2273        ):
2274            self.verify_aot_autograd(
2275                f,
2276                partial(inp_callable, req_grad=False),
2277                test_mutation=True,
2278                make_inputs_subclasses=True,
2279            )
2280        with self.assertRaisesRegex(
2281            RuntimeError,
2282            "Encountered aliased inputs that are mutated in the graph, but",
2283        ):
2284            self.verify_aot_autograd(
2285                f,
2286                partial(inp_callable, req_grad=True),
2287                test_mutation=True,
2288                make_inputs_subclasses=True,
2289            )
2290
2291    # https://github.com/pytorch/pytorch/issues/106456
2292    def test_input_mutation_noncontiguous(self):
2293        def f(a):
2294            a.mul_(2)
2295            return a + 1
2296
2297        def inp_callable(req_grad):
2298            base = torch.ones(2, 2, requires_grad=req_grad)
2299            x = base.add(1)
2300            # create a non-contiguous view to pass as an input to the compiler
2301            inp = x[:, 0]
2302            return [base], [inp]
2303
2304        self.verify_aot_autograd(
2305            f, partial(inp_callable, req_grad=False), test_mutation=True
2306        )
2307        self.verify_aot_autograd(
2308            f, partial(inp_callable, req_grad=True), test_mutation=True
2309        )
2310        with self.assertRaisesRegex(
2311            RuntimeError,
2312            "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses",
2313        ):
2314            self.verify_aot_autograd(
2315                f,
2316                partial(inp_callable, req_grad=False),
2317                test_mutation=True,
2318                make_inputs_subclasses=True,
2319            )
2320        with self.assertRaisesRegex(
2321            RuntimeError,
2322            "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses",
2323        ):
2324            self.verify_aot_autograd(
2325                f,
2326                partial(inp_callable, req_grad=True),
2327                test_mutation=True,
2328                make_inputs_subclasses=True,
2329            )
2330
2331    def test_backward_mutation_data(self):
2332        class BwMutation(torch.autograd.Function):
2333            @staticmethod
2334            def forward(ctx, x):
2335                ctx.save_for_backward(x)
2336                return x.clone()
2337
2338            @staticmethod
2339            def backward(ctx, grad_output):
2340                (x,) = ctx.saved_tensors
2341                # bw mutation
2342                x.mul_(2)
2343                return grad_output.clone()
2344
2345        def f(a, b):
2346            out = BwMutation.apply(b)
2347            return a * out
2348
2349        inp_no_grad = [
2350            torch.ones(3, 3, requires_grad=True),
2351            torch.ones(3, 3, requires_grad=False),
2352        ]
2353
2354        # Mutation on buffer that does not require grad during the backward is allowed
2355        self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
2356
2357        inp_grad = [
2358            torch.ones(3, 3, requires_grad=True),
2359            torch.ones(3, 3, requires_grad=True),
2360        ]
2361        self.verify_aot_autograd(f, inp_grad, test_mutation=True)
2362
2363    def test_backward_mutation_metadata(self):
2364        class BwMutation(torch.autograd.Function):
2365            @staticmethod
2366            def forward(ctx, a, b):
2367                ctx.save_for_backward(b)
2368                return a.clone(), b.clone()
2369
2370            @staticmethod
2371            def backward(ctx, grad_a, grad_b):
2372                (b,) = ctx.saved_tensors
2373                # bw metadata mutation
2374                b.transpose_(1, 0)
2375                return grad_a.clone(), grad_b.clone()
2376
2377        def f(a, b):
2378            a_, b_ = BwMutation.apply(a, b)
2379            out = a_ * b_
2380            return out
2381
2382        inp_no_grad = [
2383            torch.ones(3, 3, requires_grad=True),
2384            torch.ones(3, 3, requires_grad=False),
2385        ]
2386
2387        with self.assertRaisesRegex(
2388            AssertionError, "input that had its metadata mutated in the backward"
2389        ):
2390            self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
2391
2392    def test_backward_mutation_on_grad_out(self):
2393        class BwMutation(torch.autograd.Function):
2394            @staticmethod
2395            def forward(ctx, x):
2396                return x.clone()
2397
2398            @staticmethod
2399            def backward(ctx, grad_output):
2400                grad_output.mul_(2)
2401                return grad_output.clone()
2402
2403        def f(a, b):
2404            tmp = a * b
2405            out = BwMutation.apply(tmp)
2406            return out
2407
2408        inp_grad = [
2409            torch.ones(3, 3, requires_grad=True),
2410            torch.ones(3, 3, requires_grad=True),
2411        ]
2412        f_compiled = aot_function(f, nop)
2413        with self.assertRaisesRegex(
2414            AssertionError, "input to the backward that was mutated during the backward"
2415        ):
2416            out = f_compiled(*inp_grad)
2417
2418    def test_backward_mutation_forward_inputs(self):
2419        @torch.library.custom_op("_test::_clone", mutates_args={})
2420        def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
2421            return x.clone()
2422
2423        def f_fake(x, x1):
2424            return torch.empty_like(x)
2425
2426        def backward(ctx, grad):
2427            with torch.no_grad():
2428                ctx.x1.zero_()
2429            return grad * 2, None
2430
2431        def setup_context(ctx, inputs, output):
2432            (x, x1) = inputs
2433            ctx.x = x
2434            ctx.x1 = x1
2435
2436        f.register_fake(f_fake)
2437        f.register_autograd(backward, setup_context=setup_context)
2438
2439        def fn(x: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
2440            x2.mul_(5)
2441            return torch.ops._test._clone(x, x1) + x2
2442
2443        inp_x, inp_x1, inp_x2 = (
2444            torch.randn(3, requires_grad=True),
2445            torch.randn(3, requires_grad=False),
2446            torch.randn(3, requires_grad=False),
2447        )
2448
2449        ref_x, ref_x1, ref_x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone()
2450        ref_y = fn(ref_x, ref_x1, ref_x2)
2451
2452        compiled_f = aot_function(fn, nop, keep_inference_input_mutations=True)
2453
2454        x, x1, x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone()
2455        y = compiled_f(x, x1, x2)
2456
2457        # Verify mutation in forward applied and mutation in backward is not in forward
2458        self.assertEqual(ref_x, x)
2459        self.assertEqual(ref_x1, x1)
2460        self.assertEqual(ref_x2, x2)
2461        self.assertEqual(ref_y, y)
2462
2463        ref_y.sum().backward()
2464        y.sum().backward()
2465
2466        # Verify mutations in backward applied
2467        self.assertEqual(ref_x, x)
2468        self.assertEqual(ref_x1, x1)
2469        self.assertEqual(ref_x2, x2)
2470        self.assertEqual(ref_y, y)
2471
2472        self.assertEqual(ref_x.grad, x.grad)
2473        self.assertEqual(ref_x1.grad, x1.grad)
2474        self.assertEqual(ref_x2.grad, x2.grad)
2475
2476    def test_backward_mutation_forward_inputs_create_graph(self):
2477        @torch.library.custom_op("_test::_clone_create_graph", mutates_args={})
2478        def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
2479            return x.clone()
2480
2481        def f_fake(x, x1):
2482            return torch.empty_like(x)
2483
2484        def backward(ctx, grad):
2485            with torch.no_grad():
2486                ctx.x1.zero_()
2487            return grad * 2, None
2488
2489        def setup_context(ctx, inputs, output):
2490            (x, x1) = inputs
2491            ctx.x = x
2492            ctx.x1 = x1
2493
2494        f.register_fake(f_fake)
2495        f.register_autograd(backward, setup_context=setup_context)
2496
2497        def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
2498            return torch.ops._test._clone_create_graph(x, x1)
2499
2500        inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn(
2501            3, requires_grad=True
2502        )
2503
2504        ref_x, ref_x1 = inp_x.clone(), inp_x1.clone()
2505        ref_y = f(ref_x, ref_x1)
2506        ref_y.sum().backward()
2507        x, x1 = inp_x.clone(), inp_x1.clone()
2508        compiled_f = aot_function(fn, nop)
2509        y = compiled_f(x, x1)
2510        loss = y.sum()
2511        with self.assertRaisesRegex(
2512            RuntimeError,
2513            "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True",
2514        ):
2515            torch.autograd.grad(loss, inp_x, create_graph=True)
2516        # Not checking equality of ref and x as Exception is expected
2517
2518    # Partially addresses https://github.com/pytorch/pytorch/issues/106457
2519    def test_input_mutation_false_aliasing(self):
2520        def f(a, b):
2521            a.mul_(3)
2522            b.mul_(2)
2523            return a.clone().view(-1) + b.clone().view(-1)
2524
2525        # No overlap, contiguous
2526        def inp_callable1(req_grad):
2527            base = torch.ones(4, 4, requires_grad=req_grad)
2528            x = base.add(1)
2529            # create two views that share storage, but are actually non-overlapping
2530            a = x[0:2]
2531            b = x[2:4]
2532            return [base], [a, b]
2533
2534        fw_graph = self.verify_aot_autograd(
2535            f, partial(inp_callable1, req_grad=False), test_mutation=True
2536        )
2537        self.verify_aot_autograd(
2538            f, partial(inp_callable1, req_grad=True), test_mutation=True
2539        )
2540        self.verify_aot_autograd(
2541            f,
2542            partial(inp_callable1, req_grad=False),
2543            test_mutation=True,
2544            make_inputs_subclasses=True,
2545        )
2546        # Input mutations on subclasses with training graphs fail backward guards today.
2547        with self.assertRaisesRegex(
2548            AssertionError,
2549            "attempted to compile the backward with incorrect subclass metadata",
2550        ):
2551            self.verify_aot_autograd(
2552                f,
2553                partial(inp_callable1, req_grad=True),
2554                test_mutation=True,
2555                make_inputs_subclasses=True,
2556            )
2557
2558        # Important characteristic: the graph takes in 2 inputs!
2559        # That shows that we didn't try to run our complicated synthetic base logic,
2560        # because we successfully detected false aliasing across the two inputs.
2561        self.assertExpectedInline(
2562            fw_graph.code.strip(),
2563            """\
2564def forward(self, arg0_1, arg1_1):
2565    mul = torch.ops.aten.mul.Tensor(arg0_1, 3);  arg0_1 = None
2566    mul_1 = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
2567    clone = torch.ops.aten.clone.default(mul)
2568    view = torch.ops.aten.view.default(clone, [-1]);  clone = None
2569    clone_1 = torch.ops.aten.clone.default(mul_1)
2570    view_1 = torch.ops.aten.view.default(clone_1, [-1]);  clone_1 = None
2571    add = torch.ops.aten.add.Tensor(view, view_1);  view = view_1 = None
2572    return (mul, mul_1, add)""",
2573        )
2574
2575        # No overlap, non-contiguous: first tensor ends before second tensor start
2576        def inp_callable2(req_grad):
2577            base = torch.ones(256, requires_grad=req_grad)
2578            x = base.add(1)
2579            a = x.as_strided((4, 4), (8, 1), storage_offset=0)
2580            b = x.as_strided((4, 4), (8, 1), storage_offset=28)
2581            return [base], [a, b]
2582
2583        # No overlap, non-contiguous: tensors are perfectly interleaved
2584        def inp_callable3(req_grad):
2585            base = torch.ones(4, 4, requires_grad=req_grad)
2586            x = base.add(1)
2587            a = x[:, 0:2]
2588            b = x[:, 2:4]
2589            return [base], [a, b]
2590
2591        # No overlap, non-contiguous
2592        def inp_callable4(req_grad):
2593            base = torch.ones(256, requires_grad=req_grad)
2594            x = base.add(1)
2595            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
2596            b = x.as_strided((4, 4), (9, 1), storage_offset=22)
2597            return [base], [a, b]
2598
2599        # No overlap, non-contiguous
2600        def inp_callable5(req_grad):
2601            base = torch.ones(256, requires_grad=req_grad)
2602            x = base.add(1)
2603            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
2604            b = x.as_strided((4, 4), (9, 1), storage_offset=23)
2605            return [base], [a, b]
2606
2607        # No overlap, non-contiguous
2608        def inp_callable6(req_grad):
2609            base = torch.ones(256, requires_grad=req_grad)
2610            x = base.add(1)
2611            # a's last element is at offset 195 (24 total elements)
2612            a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5)
2613            # b's first element is at offset 196: no overlap
2614            b = x[196 : 196 + a.numel()]
2615            return [base], [a, b]
2616
2617        # overlap! non-contiguous
2618        def inp_callable_overlap1(req_grad):
2619            base = torch.ones(256, requires_grad=req_grad)
2620            x = base.add(1)
2621            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
2622            b = x.as_strided((4, 4), (9, 1), storage_offset=24)
2623            return [base], [a, b]
2624
2625        # overlap! non-contiguous
2626        def inp_callable_overlap2(req_grad):
2627            base = torch.ones(256, requires_grad=req_grad)
2628            x = base.add(1)
2629            a = x.as_strided((4, 4), (9, 1), storage_offset=0)
2630            b = x.as_strided((4, 4), (9, 1), storage_offset=25)
2631            return [base], [a, b]
2632
2633        # overlap! non-contiguous
2634        def inp_callable_overlap3(req_grad):
2635            base = torch.ones(256, requires_grad=req_grad)
2636            x = base.add(1)
2637            # a's last element is at offset 195 (24 total elements)
2638            a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5)
2639            # b's first element is at offset 195: overlap!
2640            b = x[195 : 195 + a.numel()]
2641            return [base], [a, b]
2642
2643        fw_graph2 = self.verify_aot_autograd(
2644            f, partial(inp_callable2, req_grad=False), test_mutation=True
2645        )
2646        fw_graph3 = self.verify_aot_autograd(
2647            f, partial(inp_callable3, req_grad=False), test_mutation=True
2648        )
2649        fw_graph4 = self.verify_aot_autograd(
2650            f, partial(inp_callable4, req_grad=False), test_mutation=True
2651        )
2652        fw_graph5 = self.verify_aot_autograd(
2653            f, partial(inp_callable5, req_grad=False), test_mutation=True
2654        )
2655        fw_graph6 = self.verify_aot_autograd(
2656            f, partial(inp_callable6, req_grad=False), test_mutation=True
2657        )
2658
2659        fw_graph_overlap1 = self.verify_aot_autograd(
2660            f, partial(inp_callable_overlap2, req_grad=False), test_mutation=True
2661        )
2662        fw_graph_overlap2 = self.verify_aot_autograd(
2663            f, partial(inp_callable_overlap1, req_grad=False), test_mutation=True
2664        )
2665
2666        # All non-overlap graphs should be the same since we detected false aliasing
2667        self.assertEqual(str(fw_graph.code), str(fw_graph2.code))
2668        self.assertEqual(str(fw_graph.code), str(fw_graph3.code))
2669        self.assertEqual(str(fw_graph.code), str(fw_graph4.code))
2670        self.assertEqual(str(fw_graph.code), str(fw_graph5.code))
2671        self.assertEqual(str(fw_graph.code), str(fw_graph6.code))
2672
2673        # All overlap graphs should be the same since we detected real aliasing
2674        self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap1.code))
2675        self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap2.code))
2676        self.assertTrue("as_strided_scatter" in str(fw_graph_overlap1.code))
2677        self.assertTrue("as_strided_scatter" in str(fw_graph_overlap2.code))
2678
2679    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
2680    def test_mem_leak_from_save_for_bw(self):
2681        # See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990
2682        # Note [Detaching saved tensors in AOTAutograd]
2683        # This program creates a ref-cycle. Long term, we should fix this ref cycle
2684        # (since it can arise, naturally albeit rarely, from uses of autograd.Function).
2685        # But AOTAutograd makes it more likely to show up from tracing user programs,
2686        # so we deal with it by manually detaching the tensors that we save for backward.
2687        # This is completely wrong and would give wrong results if we were to do double backward.
2688        # Fortunately today, double backward is explicitly banned in AOTAutograd.
2689        def f(a, b):
2690            add = a + a
2691            split = torch.functional.split(add, [4, 4], dim=1)
2692            getitem_2 = split[1]
2693            unsqueeze = getitem_2.unsqueeze(-1)
2694            mul = unsqueeze * b
2695            return (getitem_2, mul)
2696
2697        f_compiled = aot_function(f, nop)
2698        inps = [
2699            torch.ones(8, 8, device="cuda", requires_grad=True),
2700            torch.ones(1, 4, 1, device="cuda", requires_grad=True),
2701        ]
2702        mem_before = torch.cuda.memory_allocated()
2703        f_compiled(*inps)
2704        mem_after = torch.cuda.memory_allocated()
2705        self.assertTrue(mem_after == mem_before)
2706
2707    def test_output_aliases_multiple_inputs_get_correct_one(self):
2708        # a and b are aliased, but have different shapes
2709        # The first output should view off the first input, the 2nd output should view off the 2nd input
2710        def f(a, b):
2711            return a.view(a.shape), b.view(b.shape)
2712
2713        def inp_callable(req_grad):
2714            base = torch.ones(2, 2, requires_grad=req_grad)
2715            # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them.
2716            x = base.mul(2)
2717            inp1 = x.view(-1)
2718            inp2 = x[0]
2719            return [base], [inp1, inp2]
2720
2721        self.verify_aot_autograd(
2722            f, partial(inp_callable, req_grad=False), test_mutation=True
2723        )
2724        self.verify_aot_autograd(
2725            f, partial(inp_callable, req_grad=True), test_mutation=True
2726        )
2727        self.verify_aot_autograd(
2728            f,
2729            partial(inp_callable, req_grad=False),
2730            test_mutation=True,
2731            make_inputs_subclasses=True,
2732        )
2733        self.verify_aot_autograd(
2734            f,
2735            partial(inp_callable, req_grad=True),
2736            test_mutation=True,
2737            make_inputs_subclasses=True,
2738        )
2739
2740    def test_input_mutation_aliases_other_input(self):
2741        def f(a, b):
2742            a.add_(1)
2743            return a + b
2744
2745        def inp_callable(req_grad):
2746            base = torch.ones(4, 2, requires_grad=req_grad)
2747            # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them.
2748            x = base.add(1)
2749            inp1 = x[0]
2750            inp2 = x[0]
2751            return [base], [inp1, inp2]
2752
2753        self.verify_aot_autograd(
2754            f, partial(inp_callable, req_grad=False), test_mutation=True
2755        )
2756        fw_graph = self.verify_aot_autograd(
2757            f, partial(inp_callable, req_grad=True), test_mutation=True
2758        )
2759        # Important parts of the graph:
2760        # - the compiled graph takes in a base, and we generate a and b (the views) off of the base
2761        # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs
2762        # - We re-generate the views *after* the clone, to preserve view relationships.
2763        self.assertExpectedInline(
2764            fw_graph.code.strip(),
2765            """\
2766def forward(self, primals_1):
2767    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
2768    as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
2769    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
2770    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0);  clone = add = None
2771    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
2772    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
2773    add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5);  as_strided_2 = as_strided_5 = None
2774    return (as_strided_scatter, add_1)""",
2775        )  # noqa: B950
2776
2777    def test_input_mutation_aliases_other_input2(self):
2778        def f(a, b):
2779            a.add_(1)
2780            return a + b
2781
2782        def inp_callable(req_grad):
2783            base = torch.ones(2, 2, requires_grad=req_grad)
2784            x = base.add(1)
2785            inp1 = x[0]
2786            # Here, one of the aliased inputs is the base itself
2787            inp2 = x
2788            return [base], [inp1, inp2]
2789
2790        self.verify_aot_autograd(
2791            f, partial(inp_callable, req_grad=False), test_mutation=True
2792        )
2793        fw_graph = self.verify_aot_autograd(
2794            f, partial(inp_callable, req_grad=True), test_mutation=True
2795        )
2796        self.assertExpectedInline(
2797            fw_graph.code.strip(),
2798            """\
2799def forward(self, primals_1):
2800    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
2801    as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0)
2802    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
2803    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0);  clone = add = None
2804    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
2805    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
2806    add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5);  as_strided_2 = as_strided_5 = None
2807    return (as_strided_scatter, add_1)""",
2808        )  # noqa: B950
2809
2810    def test_input_mutation_aliases_and_output_alias(self):
2811        def f(a, b):
2812            # Here, we need to take care:that because and b are aliased
2813            # since a and b are aliased, we generate a view off of "updated b"
2814            a.add_(1)
2815            return b.view(b.shape)
2816
2817        def inp_callable(req_grad):
2818            base = torch.ones(2, 2, requires_grad=req_grad)
2819            x = base.add(1)
2820            return [base], [x.view(-1), x.view(-1)]
2821
2822        self.verify_aot_autograd(
2823            f, partial(inp_callable, req_grad=False), test_mutation=True
2824        )
2825        fw_graph = self.verify_aot_autograd(
2826            f, partial(inp_callable, req_grad=True), test_mutation=True
2827        )
2828        self.assertExpectedInline(
2829            fw_graph.code.strip(),
2830            """\
2831def forward(self, primals_1):
2832    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
2833    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
2834    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
2835    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0);  clone = add = None
2836    as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
2837    view_1 = torch.ops.aten.view.default(as_strided_8, [4]);  as_strided_8 = None
2838    return (as_strided_scatter, view_1)""",
2839        )  # noqa: B950
2840
2841    def test_input_aliased_with_mutation_output_alias(self):
2842        def f(a, b, c):
2843            # a and c alias
2844            c.mul_(2)
2845            # The main thing we're testing here is that
2846            # (1) We need to reconstruct c.view(-1) from the 3rd input to the forward
2847            # (2) But we need to be careful to do this *before* converting aliased inputs into synthetic bases.
2848            #     The original fw takes in 3 args, but the compiled fw takes in only 2 args.
2849            return b.add(1), c.view(-1)
2850
2851        def inp_callable(req_grad):
2852            base1 = torch.ones(2, 2, requires_grad=req_grad)
2853            base2 = torch.ones(2, 2, requires_grad=req_grad)
2854            x = base1.add(1)
2855            y = base2.add(1)
2856            return [base1, base2], [x.view(-1), y, x.view(-1)]
2857
2858        self.verify_aot_autograd(
2859            f, partial(inp_callable, req_grad=False), test_mutation=True
2860        )
2861        fw_graph = self.verify_aot_autograd(
2862            f, partial(inp_callable, req_grad=True), test_mutation=True
2863        )
2864        self.assertExpectedInline(
2865            fw_graph.code.strip(),
2866            """\
2867def forward(self, primals_1, primals_2):
2868    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
2869    as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
2870    mul = torch.ops.aten.mul.Tensor(as_strided_1, 2);  as_strided_1 = None
2871    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0);  clone = mul = None
2872    add = torch.ops.aten.add.Tensor(primals_2, 1);  primals_2 = None
2873    as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
2874    view_1 = torch.ops.aten.view.default(as_strided_7, [-1]);  as_strided_7 = None
2875    return (as_strided_scatter, add, view_1)""",
2876        )  # noqa: B950
2877
2878    def test_input_metadata_mutation_aliases(self):
2879        def f(a, b):
2880            # a and b alias, and we do a metadata mutation on a
2881            # Since we're not mutating data, then b isn't affected at all.
2882            # We expect aot autograd to not bother with constructing a synthetic base.
2883            a.t_()
2884            return a + b
2885
2886        def inp_callable(req_grad):
2887            base = torch.ones(2, 2, requires_grad=req_grad)
2888            x = base.add(1)
2889            return [base], [x.view(-1), x.view(-1)]
2890
2891        self.verify_aot_autograd(
2892            f, partial(inp_callable, req_grad=False), test_mutation=True
2893        )
2894        fw_graph = self.verify_aot_autograd(
2895            f, partial(inp_callable, req_grad=True), test_mutation=True
2896        )
2897        # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base.
2898        self.assertExpectedInline(
2899            fw_graph.code.strip(),
2900            """\
2901def forward(self, primals_1, primals_2):
2902    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
2903    add = torch.ops.aten.add.Tensor(t, primals_2);  t = primals_2 = None
2904    return (add,)""",
2905        )
2906
2907    def test_input_mutation_aliases_and_none_require_gradients(self):
2908        def f(a, b, c):
2909            # a and b alias, but neither require gradients (so they don't have a _base)
2910            # aot autograd should construct the synthetic base from `torch.Tensor(a.storage())`
2911            a.mul_(2)
2912            return b + 1, c + 1
2913
2914        def inp_callable(req_grad):
2915            base = torch.ones(2, 2)
2916            c_arg = torch.ones(2, 2, requires_grad=req_grad)
2917            x = base.add(1)
2918            return [base, c_arg], [x.view(-1), x.view(-1), c_arg]
2919
2920        self.verify_aot_autograd(
2921            f, partial(inp_callable, req_grad=False), test_mutation=True
2922        )
2923
2924        with self.assertRaisesRegex(
2925            RuntimeError, "is a tensor subclass. This is not supported today"
2926        ):
2927            self.verify_aot_autograd(
2928                f,
2929                partial(inp_callable, req_grad=False),
2930                test_mutation=True,
2931                make_inputs_subclasses=True,
2932            )
2933
2934        fw_graph = self.verify_aot_autograd(
2935            f, partial(inp_callable, req_grad=True), test_mutation=True
2936        )
2937        self.assertExpectedInline(
2938            fw_graph.code.strip(),
2939            """\
2940def forward(self, primals_1, primals_2):
2941    as_strided = torch.ops.aten.as_strided.default(primals_1, [4], [1], 0)
2942    mul = torch.ops.aten.mul.Tensor(as_strided, 2);  as_strided = None
2943    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(primals_1, mul, [4], [1], 0);  primals_1 = mul = None
2944    as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
2945    add = torch.ops.aten.add.Tensor(as_strided_3, 1);  as_strided_3 = None
2946    add_1 = torch.ops.aten.add.Tensor(primals_2, 1);  primals_2 = None
2947    return (as_strided_scatter, add, add_1)""",
2948        )  # noqa: B950
2949
2950    @skipIfDynamoInput("Fails with dynamo")
2951    def test_input_mutation_aliases_bases_out_of_order(self):
2952        # This tests our calling convention: if b and d are aliased, then the outer calling convention
2953        # that we send to the compiled forward becomes:
2954        # (b_d_base, a, c)
2955        # Importantly, even though a and c alias in our test, neither inputs are mutated,
2956        # So we don't need to do the base construction / deconstruction
2957        def f(a, b, c, d):
2958            b.add_(1)
2959            d.unsqueeze_(0)
2960            return a + c + d, b.view(-1)
2961
2962        def inp_callable(req_grad):
2963            base1 = torch.ones(2, 2, requires_grad=req_grad)
2964            base2 = torch.ones(2, 2, requires_grad=req_grad)
2965            x1 = base1.add(1)
2966            x2 = base2.add(1)
2967            # a and c alias, b and d alias
2968            return [base1, base2], [x1.view(-1), x2.view(-1), x1.view(-1), x2.view(-1)]
2969
2970        self.verify_aot_autograd(
2971            f, partial(inp_callable, req_grad=False), test_mutation=True
2972        )
2973
2974        with self.assertRaisesRegex(
2975            RuntimeError,
2976            "Metadata mutations are currently not allowed on tensor subclasses",
2977        ):
2978            self.verify_aot_autograd(
2979                f,
2980                partial(inp_callable, req_grad=False),
2981                test_mutation=True,
2982                make_inputs_subclasses=True,
2983            )
2984
2985        fw_graph = self.verify_aot_autograd(
2986            f, partial(inp_callable, req_grad=True), test_mutation=True
2987        )
2988        # 3 graph inputs: (b_d_base, a, c)
2989        # 2 returns: (b_updated, a+c+d)
2990        # (there are 2 original fw outs, but one is a view of b so it's not part of the graph)
2991        # (there are also 2 input mutations, but one is a metadata-only mutation so the compiled forward doesn't return it)
2992        self.assertExpectedInline(
2993            fw_graph.code.strip(),
2994            """\
2995def forward(self, primals_1, primals_2, primals_3):
2996    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
2997    as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
2998    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
2999    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0);  clone = add = None
3000    add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3);  primals_2 = primals_3 = None
3001    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
3002    unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0);  as_strided_5 = None
3003    add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1);  add_1 = None
3004    as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
3005    view_2 = torch.ops.aten.view.default(as_strided_14, [-1]);  as_strided_14 = None
3006    return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
3007        )  # noqa: B950
3008
3009    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
3010    def test_synthetic_base_base_attribute_is_none(self):
3011        def f(a, b):
3012            a.add_(1)
3013            return a + b
3014
3015        def inp_callable():
3016            base = torch.ones(4, 4, device="cuda")
3017            # detach() so that none of the inputs have a ._base attribute.
3018            a = base[0].detach()
3019            b = base[1].detach()
3020            base2 = torch.ones(2, 2, requires_grad=True)
3021            return [base], [a, b]
3022
3023        self.verify_aot_autograd(f, inp_callable, test_mutation=True)
3024
3025    def test_input_mutation_alias_everything(self):
3026        # Mondo test that tests a combination of:
3027        # input is mutated, that aliases another input (so we make a synthetic base)
3028        # an output is an alias of another output
3029        # an output is an alias of an intermediate
3030        # a and c are aliased
3031        def f(a, b, c):
3032            c.mul_(2)  # mutates c
3033            b.t_()  # metadata mutate b
3034            tmp = a + c
3035            out1 = tmp.view(-1)
3036            out2 = b.t()
3037            out3 = out1.unsqueeze(0)
3038            # out1 and out3 are aliases of an intermediate, and alias each other!
3039            # out2 aliases an input, so we don't return it
3040            return out1, out2, out3
3041
3042        def inp_callable(req_grad):
3043            base1 = torch.ones(2, 2, requires_grad=req_grad)
3044            base2 = torch.ones(2, 2, requires_grad=req_grad)
3045            # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them.
3046            base1_ = base1.add(1)
3047            base2_ = base2.add(1)
3048            a = base1_.view(-1)
3049            b = base2_
3050            c = base1_.view(-1)
3051            return [base1, base2], [a, b, c]
3052
3053        self.verify_aot_autograd(
3054            f, partial(inp_callable, req_grad=False), test_mutation=True
3055        )
3056        fw_graph = self.verify_aot_autograd(
3057            f, partial(inp_callable, req_grad=True), test_mutation=True
3058        )
3059        # Expected:
3060        # - 2 inputs in the forward: synthetic_base_a_c, b
3061        # - 1 output in the forward: "tmp"
3062        #   out2 is an alias of an input, and will be generated off of b outside of the compiled fn
3063        #   out1 and out3 are aliases of tmp, that we generate outside of the compiled function
3064        self.assertExpectedInline(
3065            fw_graph.code.strip(),
3066            """\
3067def forward(self, primals_1, primals_2):
3068    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
3069    view = torch.ops.aten.view.default(primals_2, [2, 2]);  primals_2 = None
3070    as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
3071    mul = torch.ops.aten.mul.Tensor(as_strided_1, 2);  as_strided_1 = None
3072    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0);  clone = mul = None
3073    as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
3074    t = torch.ops.aten.t.default(view);  view = None
3075    as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
3076    add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2);  as_strided_5 = as_strided_2 = None
3077    view_1 = torch.ops.aten.view.default(add, [-1])
3078    t_1 = torch.ops.aten.t.default(t)
3079    unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
3080    return (as_strided_scatter, t, view_1, t_1, unsqueeze, add)""",
3081        )  # noqa: B950
3082
3083    def test_dynamic_shape_output_not_in_bw_graph(self):
3084        def f(x):
3085            return [x + 1, x.shape[0]]
3086
3087        inp = torch.ones(5, requires_grad=True)
3088        bw_graph_cell = [None]
3089        compiled_f = aot_function(
3090            f,
3091            fw_compiler=nop,
3092            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
3093            decompositions={},
3094            keep_inference_input_mutations=False,
3095            dynamic=True,
3096        )
3097        out = compiled_f(inp)
3098        out[0].sum().backward()
3099        # The important bit: the forward fn returns 2 outputs,
3100        # but one of them is a symint so we should only see
3101        # 1 grad_output as an input to the backward graph.
3102        # (Otherwise, autograd will plumb a None as the value of the grad_output,
3103        # which causes inductor to complain).
3104        self.assertExpectedInline(
3105            bw_graph_cell[0].code.strip(),
3106            """\
3107def forward(self, tangents_1):
3108    return (tangents_1,)""",
3109        )
3110
3111    def test_no_grad_input_output(self):
3112        def f(a, b):
3113            return a.cos(), b.cos(), a * b
3114
3115        inp_thunks = [
3116            lambda: torch.randn(5, requires_grad=True),
3117            lambda: torch.randn(5, requires_grad=False),
3118        ]
3119        for inps in itertools.product(inp_thunks, repeat=2):
3120            inps = [i() for i in inps]
3121            self.verify_aot_autograd(f, inps)
3122
3123    def test_some_output_requires_grad_input_doesnt(self):
3124        def f(a, b):
3125            a_view = a.view(-1)
3126            a_view.requires_grad_(True)
3127            return a_view
3128
3129        inp = [torch.randn(3, 3), torch.randn(3, 3, requires_grad=True)]
3130        self.verify_aot_autograd(f, inp)
3131
3132    def test_some_outputs_dont_require_grad_view(self):
3133        def f(a, b):
3134            return a.detach(), b
3135
3136        inp = [
3137            torch.randn(3, 3, requires_grad=True),
3138            torch.randn(3, 3, requires_grad=True),
3139        ]
3140        self.verify_aot_autograd(f, inp)
3141
3142    def test_some_outputs_dont_require_grad_non_view(self):
3143        def f(a, b):
3144            return a.add(1).detach(), b
3145
3146        inp = [
3147            torch.randn(3, 3, requires_grad=True),
3148            torch.randn(3, 3, requires_grad=True),
3149        ]
3150        self.verify_aot_autograd(f, inp)
3151
3152    def test_inner_grad(self):
3153        def foo(x):
3154            y = torch.exp(x)
3155            z = torch.autograd.grad(y, x)
3156            return z
3157
3158        inps = [torch.randn((), requires_grad=True)]
3159        self.verify_aot_autograd(foo, inps)
3160
3161    def test_grad_context(self):
3162        def foo(x):
3163            return x * 2
3164
3165        inps = [torch.randn((), requires_grad=True)]
3166        graph_size = None
3167
3168        def get_graph_size(fx_g, _):
3169            nonlocal graph_size
3170            graph_size = len(fx_g.graph.nodes)
3171            return fx_g
3172
3173        f = aot_function(foo, nop, get_graph_size)
3174        with torch.set_grad_enabled(False):
3175            f(*inps)
3176        self.assertIsNone(graph_size)
3177
3178        f = aot_function(foo, nop, get_graph_size)
3179        with torch.set_grad_enabled(True):
3180            out = f(*inps)
3181            self.assertIsNone(graph_size)
3182            out.sum().backward()
3183            self.assertTrue(graph_size > 2)
3184
3185    def test_output_dict(self):
3186        def f(x):
3187            return {"a": x, "b": x}
3188
3189        inp = [torch.randn(3, 3, requires_grad=True)]
3190        self.verify_aot_autograd(f, inp)
3191
3192        def f(x, y):
3193            return {"a": x, "b": y + x}
3194
3195        inp = [torch.randn(3, requires_grad=True), torch.randn(3)]
3196        self.verify_aot_autograd(f, inp)
3197
3198        def f(x):
3199            new_d = {}
3200            for k in x:
3201                new_d[k] = x[k] * 2
3202            return new_d
3203
3204        a = torch.randn(3, requires_grad=True)
3205        b = torch.randn(3, requires_grad=True)
3206
3207        def inp_callable():
3208            inps = [{"a": a, "b": b}]
3209            return inps, inps
3210
3211        self.verify_aot_autograd(f, inp_callable)
3212
3213    def test_module(self):
3214        mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
3215        compiled_mod = compiled_module(mod, nop, nop)
3216        inp = torch.randn(32, 32)
3217        ref_out = mod(inp)
3218        ref_out.sum().backward()
3219        ref_grads = sorted([(name, p.grad) for name, p in mod.named_parameters()])
3220        out = compiled_mod(inp)
3221        out.sum().backward()
3222        grads = sorted([(name, p.grad) for name, p in mod.named_parameters()])
3223        self.assertEqual((out, grads), (ref_out, ref_grads))
3224
3225    def test_batchnorm(self):
3226        mod = compiled_module(nn.BatchNorm2d(4), nop, nop)
3227        x = torch.ones(1, 4, 2, 2)
3228        mod(x).sum().backward()
3229
3230    def test_list_codegen(self):
3231        def list_nop(f, _):
3232            def g(inps):
3233                return f(*inps)
3234
3235            g._boxed_call = True
3236            return g
3237
3238        def f(a, b, c):
3239            return a.sin() * b.cos() * c.sin()
3240
3241        f = aot_function(f, list_nop)
3242        inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
3243        f(*inp).sum().backward()
3244
3245    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
3246    def test_compilation_context(self, counter):
3247        def f(x):
3248            return x.sin().sin()
3249
3250        count = []
3251
3252        def compiler(fx_g, _):
3253            context = get_aot_compilation_context()
3254            count.append((context[0], len(fx_g.graph.nodes)))
3255            return fx_g
3256
3257        f = aot_function(f, compiler)
3258        out = f(torch.randn(5, requires_grad=True))
3259        f = aot_function(f, compiler)
3260        f(torch.randn(5))
3261        out.sum().backward()
3262        self.assertExpectedInline(
3263            str(count),
3264            """[(['0_forward'], 4), (['1_inference'], 4), (['0_backward'], 8)]""",
3265        )
3266
3267    def test_dupe_arg(self):
3268        def f(x, y):
3269            return x + y
3270
3271        x = torch.randn(3, 3, requires_grad=True)
3272        self.verify_aot_autograd(f, [x, x])
3273
3274    def test_dupe_arg_torture(self):
3275        def f(x, y):
3276            x.t_()
3277            y.unsqueeze_(0)
3278            return x + y
3279
3280        x = torch.randn(3, 3, requires_grad=True).clone()
3281        self.verify_aot_autograd(f, [x, x])
3282
3283    # See https://github.com/pytorch/pytorch/issues/100224
3284    def test_dupe_arg_returned_as_output(self):
3285        def f(a, b, a_):
3286            a[0].add_(1)
3287            return a_
3288
3289        f_compiled = aot_function(f, nop)
3290        a = torch.ones(2)
3291        b = torch.ones(2)
3292        out_ref = f(a, b, a)
3293
3294        a2 = torch.ones(2)
3295        b2 = torch.ones(2)
3296        out_test = f_compiled(a2, b2, a2)
3297
3298        self.assertEqual(out_ref, out_test)
3299        self.assertEqual(a, a2)
3300
3301    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
3302    @patch("torch._functorch.config.debug_assert", True)
3303    def test_invalid_dupe_left_bias(self, counter):
3304        # This test checks that, just because only the first
3305        # argument did a metadata mutation, we still correctly
3306        # switch to strategy 2 (deduplicate)
3307        # See: https://github.com/pytorch/pytorch/pull/89896#discussion_r1036224447
3308        class F(torch.nn.Module):
3309            def forward(self, x, y):
3310                x.t_()
3311                return (x + y,)
3312
3313        x = torch.randn(3, 3, requires_grad=True).clone()
3314        y = torch.randn(3, 3, requires_grad=True)
3315        self.verify_aot_autograd(F(), [x, x])
3316
3317        fxx = aot_module_simplified(F(), (x, x), nop)
3318        self.assertExpectedRaisesInline(
3319            AssertionError,
3320            lambda: fxx(x, y),
3321            """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case.  This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""",  # noqa: B950
3322        )
3323
3324    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
3325    @patch("torch._functorch.config.debug_assert", True)
3326    def test_invalid_dupe(self, counter):
3327        self._test_invalid_dupe(counter, fake=False)
3328
3329    # See Note: Dynamo recompilation guarding invalid grad for why this test exists
3330    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
3331    @patch("torch._functorch.config.debug_assert", True)
3332    def test_invalid_dupe_fake(self, counter):
3333        self._test_invalid_dupe(counter, fake=True)
3334
3335    def _test_invalid_dupe(self, counter, fake):
3336        class F(torch.nn.Module):
3337            def forward(self, x, y):
3338                x.unsqueeze_(0)
3339                y.unsqueeze_(0)
3340                return (x + y,)
3341
3342        x = torch.randn(3, 3, requires_grad=True).clone()
3343        y = torch.randn(3, 3, requires_grad=True).clone()
3344
3345        if fake:
3346            shape_env = ShapeEnv()
3347            fake_mode = FakeTensorMode(shape_env=shape_env)
3348
3349            fake_x = fake_mode.from_tensor(x)
3350            fake_y = fake_mode.from_tensor(y)
3351
3352        if fake:
3353            fxy = aot_module_simplified(F(), (fake_x, fake_y), nop)
3354        else:
3355            fxy = aot_module_simplified(F(), (x, y), nop)
3356
3357        fxy(x, y)
3358        x = torch.randn(3, 3, requires_grad=True).clone()
3359        y = torch.randn(3, 3, requires_grad=True).clone()
3360        fxy(x, x)  # is ok!
3361
3362        if fake:
3363            fxx = aot_module_simplified(F(), (fake_x, fake_x), nop)
3364        else:
3365            fxx = aot_module_simplified(F(), (x, x), nop)
3366
3367        x = torch.randn(3, 3, requires_grad=True).clone()
3368        y = torch.randn(3, 3, requires_grad=True).clone()
3369        fxx(x, x)
3370        # Note This should not raise! Once we have guards in place here,
3371        # we will have this working correctly, as it should recompile.
3372        x = torch.randn(3, 3, requires_grad=True).clone()
3373        y = torch.randn(3, 3, requires_grad=True).clone()
3374        self.assertExpectedRaisesInline(
3375            AssertionError,
3376            lambda: fxx(x, y),
3377            """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case.  This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""",  # noqa: B950
3378        )
3379
3380    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
3381    @patch("torch._functorch.config.debug_assert", True)
3382    def test_invalid_requires_grad(self, counter):
3383        self._test_invalid_requires_grad(counter, fake=False)
3384
3385    # See Note: Dynamo recompilation guarding invalid grad for why this test exists
3386    @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count)
3387    @patch("torch._functorch.config.debug_assert", True)
3388    def test_invalid_requires_grad_fake(self, counter):
3389        self._test_invalid_requires_grad(counter, fake=True)
3390
3391    def _test_invalid_requires_grad(self, counter, fake):
3392        class F(torch.nn.Module):
3393            def forward(self, x, y):
3394                return (x + y,)
3395
3396        x = torch.randn(3, 3, requires_grad=True)
3397        y = torch.randn(3, 3, requires_grad=True)
3398        z = torch.randn(3, 3, requires_grad=False)
3399
3400        if fake:
3401            shape_env = ShapeEnv()
3402            fake_mode = FakeTensorMode(shape_env=shape_env)
3403
3404            fake_x = fake_mode.from_tensor(x)
3405            fake_y = fake_mode.from_tensor(y)
3406            fake_z = fake_mode.from_tensor(z)
3407
3408        if fake:
3409            fxy = aot_module_simplified(F(), (fake_x, fake_y), nop)
3410        else:
3411            fxy = aot_module_simplified(F(), (x, y), nop)
3412
3413        compare_equal_outs_and_grads(self, F(), fxy, (x, y))
3414        compare_equal_outs_and_grads(self, F(), fxy, (x, z))
3415
3416        if fake:
3417            fxz = aot_module_simplified(F(), (fake_x, fake_z), nop)
3418        else:
3419            fxz = aot_module_simplified(F(), (x, z), nop)
3420
3421        compare_equal_outs_and_grads(self, F(), fxz, (x, z))
3422
3423        self.assertExpectedRaisesInline(
3424            AssertionError,
3425            lambda: fxz(x, y),
3426            """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case.  This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""",  # noqa: B950
3427        )
3428
3429    def test_custom_autograd(self):
3430        class CustomFn(torch.autograd.Function):
3431            @staticmethod
3432            def forward(ctx, x):
3433                return x.clone()
3434
3435            @staticmethod
3436            def backward(ctx, grad_output):
3437                return grad_output + 1
3438
3439        def f(x):
3440            return CustomFn.apply(x)
3441
3442        self.verify_aot_autograd(f, [torch.randn(3)])
3443
3444    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
3445    def test_autocast_disable_guard(self):
3446        with torch._C._DisableAutocast():
3447            x = torch.rand([4, 4]).cuda()
3448            y = x @ x
3449            self.assertEqual(y.dtype, torch.float32)
3450
3451    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
3452    def test_nonidempotent_amp(self):
3453        def f(self_s_emb, add_3):
3454            einsum_2 = torch.functional.einsum("ah,th->t", self_s_emb, add_3)
3455            log_softmax_2 = einsum_2.log_softmax(-1)
3456            return (log_softmax_2,)
3457
3458        args = [
3459            torch.rand((1, 256), dtype=torch.float32, device="cuda"),
3460            torch.rand((30, 256), dtype=torch.float16, device="cuda"),
3461        ]
3462        with torch.cuda.amp.autocast(enabled=True):
3463            self.verify_aot_autograd(f, args)
3464
3465        args = [e.requires_grad_(True) for e in args]
3466        with torch.cuda.amp.autocast(enabled=True):
3467            self.verify_aot_autograd(f, args)
3468
3469    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
3470    @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable")
3471    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/96560
3472    def test_batch_norm_amp(self):
3473        device = "cuda"
3474        input_dtype = torch.float16
3475        param_dtype = torch.float32
3476        weight, bias = (
3477            torch.ones(64, device=device, dtype=param_dtype, requires_grad=True)
3478            for _ in range(2)
3479        )
3480        running_mean, running_var = (
3481            torch.ones(64, device=device, dtype=param_dtype) for _ in range(2)
3482        )
3483
3484        def bn(x):
3485            return torch.ops.aten.cudnn_batch_norm(
3486                x,
3487                weight,
3488                bias,
3489                running_mean,
3490                running_var,
3491                False,
3492                0.1,
3493                1e-05,
3494            )
3495
3496        inp = torch.ones(
3497            torch.Size([16, 64, 112, 112]), dtype=input_dtype, device=device
3498        )
3499
3500        ref = bn(inp)
3501        cudnn_batch_norm_decomp = torch._decomp.get_decompositions(
3502            {torch.ops.aten.cudnn_batch_norm}
3503        )
3504        aot_fn = make_fx(bn, decomposition_table=cudnn_batch_norm_decomp)(inp)
3505        res = aot_fn(inp)
3506        for a, b in zip(ref, res):
3507            assert torch.allclose(a, b)
3508
3509    def test_output_op_depending_on_symint(self):
3510        """
3511        It won't be obvious from reading this test what it's testing for.  We should probably make it into a more
3512        focused unit test.
3513
3514        An issue with the following program was the expand op would end up depending on a symint whose proxy was
3515        incorrectly associated with one of the grad tensors rather than input tensors.  It broke partitioner logic
3516        and the net result was aot_function failed to produce a function and threw an exception instead.
3517        """
3518        inp = torch.randn(5, requires_grad=True)
3519
3520        def f(x):
3521            return x.expand(x.shape)
3522
3523        # TODO(whc) make this work (test setup is wrong somehow)
3524        # joint_forward_backward = create_joint_forward_backward(f)
3525        # out = f(inp)
3526        # joint_inputs =  ([inp], [out.detach().contiguous()])
3527        # fx_g = make_fx(joint_forward_backward)(*joint_inputs)
3528        # TODO: assert outputs of fwd graph trace to correct symint
3529
3530        # e2e test that fails without symint clone fix
3531        af = aot_function(
3532            f,
3533            nop,
3534            partition_fn=partial(
3535                min_cut_rematerialization_partition, compiler="inductor"
3536            ),
3537            dynamic=True,
3538        )
3539        out = af(inp)
3540        self.assertEqual(out, f(inp))
3541
3542    def test_inference_mode(self):
3543        m = torch.nn.Linear(4, 4)
3544        inp = torch.randn(4, 4)
3545
3546        aot_mod = aot_module(m, fw_compiler=nop)
3547
3548        with torch.inference_mode():
3549            out_ref = m(inp)
3550            out_test = aot_mod(inp)
3551        self.assertEqual(out_ref, out_test)
3552
3553    def test_default_partitioner_saves_symints_not_tensors_for_bw(self):
3554        """
3555        In this test, the important thing is that primals_1 is **only** needed in the backward
3556        in order to grab its sizes.
3557        We need to assert that what we save for the backward are the tensor's sizes, and not the tensor itself.
3558
3559        The way this test is set up, it will actually fail if we try to save the input tensor for backward.
3560        Why?
3561        b.masked_fill_(c, 0) has a backward that requires knowing a's sizes
3562        b.masked_fill_(c, 0) **also** mutates a (because b and a are aliased)
3563        The autograd engine yells at us if we save "a" for backward, and then try to mutate it.
3564        """
3565        inp = torch.randn(2, 2, requires_grad=True)
3566
3567        def f(a):
3568            b = a[0]
3569            c = torch.ones_like(b, dtype=torch.bool)
3570            d = b.masked_fill_(c, 0)
3571            return d
3572
3573        compiled_f = aot_function(f, nop, dynamic=True)
3574        inp_ref = torch.ones(2, 2, requires_grad=True)
3575        inp_test = torch.ones(2, 2, requires_grad=True)
3576
3577        out_ref = f(inp_ref.clone())
3578        out_test = compiled_f(inp_test.clone())
3579
3580        self.assertEqual(out_ref, out_test)
3581
3582        out_ref.sum().backward()
3583        out_test.sum().backward()
3584
3585        self.assertEqual(inp_ref.grad, inp_test.grad)
3586
3587    def test_buffer_copied_in_graph(self):
3588        class MyModel(torch.nn.Module):
3589            def __init__(self) -> None:
3590                super().__init__()
3591                self.buf = torch.nn.Buffer(torch.zeros(1))
3592                self.w1 = torch.nn.Parameter(torch.zeros(1))
3593                self.w2 = torch.nn.Parameter(torch.zeros(1))
3594
3595            def forward(self, x):
3596                self.buf.add_(1)
3597                return (self.w1 * x * self.w2).sum() + self.buf.sum()
3598
3599        model_for_eager = MyModel()
3600        model_for_compile = copy.deepcopy(model_for_eager)
3601
3602        fw_graph_cell = [None]
3603        compiled_f = aot_module(
3604            model_for_compile,
3605            fw_compiler=make_boxed_compiler(
3606                partial(extract_graph, graph_cell=fw_graph_cell)
3607            ),
3608            bw_compiler=nop,
3609            keep_inference_input_mutations=True,
3610        )
3611        inp_ref = torch.ones(1, requires_grad=True)
3612        inp_test = torch.ones(1, requires_grad=True)
3613
3614        out_ref = model_for_eager(inp_ref.clone())
3615        out_test = compiled_f(inp_test.clone())
3616
3617        self.assertExpectedInline(
3618            fw_graph_cell[0].code.strip(),
3619            """\
3620def forward(self, primals_1, primals_2, primals_3, primals_4):
3621    add = torch.ops.aten.add.Tensor(primals_3, 1)
3622    mul = torch.ops.aten.mul.Tensor(primals_1, primals_4)
3623    mul_1 = torch.ops.aten.mul.Tensor(mul, primals_2)
3624    sum_1 = torch.ops.aten.sum.default(mul_1);  mul_1 = None
3625    sum_2 = torch.ops.aten.sum.default(add)
3626    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
3627    copy_ = torch.ops.aten.copy_.default(primals_3, add);  primals_3 = add = copy_ = None
3628    return (add_1, primals_1, primals_2, primals_4, mul)""",
3629        )
3630
3631        self.assertEqual(out_ref, out_test)
3632
3633        out_ref.sum().backward()
3634        out_test.sum().backward()
3635
3636        eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
3637        compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]
3638
3639        self.assertEqual(eager_grads, compile_grads)
3640        self.assertEqual(inp_ref.grad, inp_test.grad)
3641
3642    def test_buffer_copied_in_graph_with_different_shapes(self):
3643        class MyModel(torch.nn.Module):
3644            def __init__(self) -> None:
3645                super().__init__()
3646                self.buf = torch.nn.Buffer(torch.ones(4, 4))
3647                self.w = torch.nn.Parameter(
3648                    torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]])
3649                )
3650
3651            def forward(self, x):
3652                self.buf.add_(1)
3653                return (self.w @ x).sum() + self.buf.sum()
3654
3655        model_for_eager = MyModel()
3656        model_for_compile = copy.deepcopy(model_for_eager)
3657
3658        fw_graph_cell = [None]
3659        compiled_f = aot_module(
3660            model_for_compile,
3661            fw_compiler=make_boxed_compiler(
3662                partial(extract_graph, graph_cell=fw_graph_cell)
3663            ),
3664            bw_compiler=nop,
3665            keep_inference_input_mutations=True,
3666        )
3667        inp_ref = torch.ones(2, 4, requires_grad=True)
3668        inp_test = torch.ones(2, 4, requires_grad=True)
3669
3670        out_ref = model_for_eager(inp_ref.clone())
3671        out_test = compiled_f(inp_test.clone())
3672
3673        self.assertExpectedInline(
3674            fw_graph_cell[0].code.strip(),
3675            """\
3676def forward(self, primals_1, primals_2, primals_3):
3677    add = torch.ops.aten.add.Tensor(primals_2, 1)
3678    mm = torch.ops.aten.mm.default(primals_1, primals_3)
3679    sum_1 = torch.ops.aten.sum.default(mm);  mm = None
3680    sum_2 = torch.ops.aten.sum.default(add)
3681    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
3682    copy_ = torch.ops.aten.copy_.default(primals_2, add);  primals_2 = add = copy_ = None
3683    return (add_1, primals_1, primals_3)""",
3684        )
3685        self.assertEqual(out_ref, out_test)
3686
3687        out_ref.sum().backward()
3688        out_test.sum().backward()
3689
3690        eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
3691        compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]
3692
3693        self.assertEqual(eager_grads, compile_grads)
3694
3695        self.assertEqual(inp_ref.grad, inp_test.grad)
3696
3697    def test_buffer_batch_norm(self):
3698        class MyModel(torch.nn.Module):
3699            def __init__(self) -> None:
3700                super().__init__()
3701                self.m = torch.nn.BatchNorm1d(100)
3702
3703            def forward(self, x):
3704                return self.m(x)
3705
3706        model_for_eager = MyModel()
3707        model_for_compile = copy.deepcopy(model_for_eager)
3708
3709        fw_graph_cell = [None]
3710        bw_graph_cell = [None]
3711        compiled_f = aot_module(
3712            model_for_compile,
3713            fw_compiler=make_boxed_compiler(
3714                partial(extract_graph, graph_cell=fw_graph_cell)
3715            ),
3716            bw_compiler=make_boxed_compiler(
3717                partial(extract_graph, graph_cell=bw_graph_cell)
3718            ),
3719            keep_inference_input_mutations=True,
3720        )
3721        inp_ref = torch.ones(20, 100, requires_grad=True)
3722        inp_test = torch.ones(20, 100, requires_grad=True)
3723
3724        out_ref = model_for_eager(inp_ref.clone())
3725        out_test = compiled_f(inp_test.clone())
3726
3727        self.assertExpectedInline(
3728            fw_graph_cell[0].code.strip(),
3729            """\
3730def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6):
3731    add = torch.ops.aten.add.Tensor(primals_5, 1)
3732    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(primals_6, primals_1, primals_2, primals_3, primals_4, True, 0.1, 1e-05);  primals_2 = None
3733    getitem = _native_batch_norm_legit_functional[0]
3734    getitem_1 = _native_batch_norm_legit_functional[1]
3735    getitem_2 = _native_batch_norm_legit_functional[2]
3736    getitem_3 = _native_batch_norm_legit_functional[3]
3737    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
3738    copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3);  primals_3 = copy_ = None
3739    copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4);  primals_4 = copy__1 = None
3740    copy__2 = torch.ops.aten.copy_.default(primals_5, add);  primals_5 = add = copy__2 = None
3741    return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""",  # noqa: B950
3742        )
3743
3744        self.assertEqual(out_ref, out_test)
3745
3746        out_ref.sum().backward()
3747        out_test.sum().backward()
3748
3749        eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
3750        compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]
3751        self.assertEqual(eager_grads, compile_grads)
3752
3753        self.assertExpectedInline(
3754            bw_graph_cell[0].code.strip(),
3755            """\
3756def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4, tangents_1):
3757    native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, primals_6, primals_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]);  tangents_1 = primals_6 = primals_1 = getitem_3 = getitem_4 = getitem_1 = getitem_2 = None
3758    getitem_5 = native_batch_norm_backward[0]
3759    getitem_6 = native_batch_norm_backward[1]
3760    getitem_7 = native_batch_norm_backward[2];  native_batch_norm_backward = None
3761    return (getitem_6, getitem_7, None, None, None, getitem_5)""",  # noqa: B950
3762        )
3763
3764        self.assertEqual(inp_ref.grad, inp_test.grad)
3765
3766    def test_new_inp_requires_grad_now(self):
3767        def f(x, y):
3768            return x.add_(y)
3769
3770        fw_graph_cell = [None]
3771        bw_graph_cell = [None]
3772        compiled_f = aot_function(
3773            f,
3774            fw_compiler=make_boxed_compiler(
3775                partial(extract_graph, graph_cell=fw_graph_cell)
3776            ),
3777            bw_compiler=make_boxed_compiler(
3778                partial(extract_graph, graph_cell=bw_graph_cell)
3779            ),
3780            keep_inference_input_mutations=True,
3781        )
3782
3783        inp_ref = (
3784            torch.ones(20, 100, requires_grad=False),
3785            torch.ones(20, 100, requires_grad=True),
3786        )
3787        inp_test = (
3788            torch.ones(20, 100, requires_grad=False),
3789            torch.ones(20, 100, requires_grad=True),
3790        )
3791
3792        out_ref = f(*inp_ref)
3793        out_test = compiled_f(*inp_test)
3794
3795        # There is no copy_ method
3796        self.assertExpectedInline(
3797            fw_graph_cell[0].code.strip(),
3798            """\
3799def forward(self, primals_1, primals_2):
3800    clone = torch.ops.aten.clone.default(primals_1);  primals_1 = None
3801    add = torch.ops.aten.add.Tensor(clone, primals_2);  clone = primals_2 = None
3802    return (add, add)""",
3803        )  # noqa: B950
3804
3805        self.assertEqual(out_ref, out_test)
3806
3807        out_ref.sum().backward()
3808        out_test.sum().backward()
3809
3810        self.assertExpectedInline(
3811            bw_graph_cell[0].code.strip(),
3812            """\
3813def forward(self, tangents_1):
3814    return (None, tangents_1)""",
3815        )  # noqa: B950
3816
3817    def test_real_weights_in_symbolic_mode(self):
3818        from functorch.experimental import functionalize
3819
3820        class M(torch.nn.Module):
3821            def __init__(self) -> None:
3822                super().__init__()
3823                self.linear = torch.nn.Linear(5, 5)
3824
3825            def forward(self, x):
3826                x = self.linear(x)
3827                return x
3828
3829        m = M().eval()
3830
3831        inp = torch.randn(2, 5)
3832
3833        gm = make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
3834        self.assertEqual(gm(torch.ones(2, 5)), m(torch.ones(2, 5)))
3835
3836        gm_functionalized = make_fx(
3837            functionalize(
3838                gm,
3839            ),
3840            tracing_mode="symbolic",
3841            _allow_non_fake_inputs=True,
3842        )(inp)
3843        self.assertEqual(gm_functionalized(torch.ones(2, 5)), m(torch.ones(2, 5)))
3844
3845        inp_count = 0
3846        for node in gm.graph.nodes:
3847            if node.op == "placeholder":
3848                inp_count += 1
3849
3850        # No more param lifting
3851        self.assertEqual(inp_count, 1)
3852
3853        inp_count = 0
3854        for node in gm_functionalized.graph.nodes:
3855            if node.op == "placeholder":
3856                inp_count += 1
3857
3858        # No more param lifting
3859        self.assertEqual(inp_count, 1)
3860
3861        with self.assertRaisesRegex(
3862            Exception, "Please convert all Tensors to FakeTensors"
3863        ):
3864            make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=False)(
3865                torch.randn(2, 5)
3866            )
3867
3868    def test_real_weights_in_symbolic_mode_with_inplace_ops(self):
3869        class M(torch.nn.Module):
3870            def __init__(self) -> None:
3871                super().__init__()
3872                self.buffer = torch.nn.Buffer(torch.ones(4, 5))
3873
3874            def forward(self, x):
3875                y = self.buffer.add_(3)
3876                y.resize_([20])
3877                assert y.shape == self.buffer.shape
3878                return x.sum() + self.buffer.sum()
3879
3880        m = M().eval()
3881        inp = torch.randn(2, 5)
3882        # inplace mutation on attr is not allowed
3883        with self.assertRaisesRegex(Exception, "Can't call metadata"):
3884            make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
3885
3886    def _compile_and_erase_bases(self, *output_view_indices):
3887        # Overrides _base and _view_func tensor attributes, so as to avoid the view-replay
3888        # execution path when reconstructing views.
3889        class NoViewReplayTensor(torch.Tensor):
3890            @property
3891            def _base(self):
3892                return None
3893
3894            @property
3895            def _view_func(self):
3896                return None
3897
3898        # Wraps the outputs that are views of the FX graph 'g' with NoViewReplayTensor,
3899        # since they are the only ones that will get reconstructed.
3900        def wrapper(g, *args, **kwargs):
3901            outs = list(g(*args, **kwargs))
3902            for i in output_view_indices:
3903                outs[i] = NoViewReplayTensor(outs[i])
3904            return tuple(outs)
3905
3906        return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))
3907
3908    def test_output_aliases_input_view_meta_replay(self):
3909        @self._compile_and_erase_bases(0)
3910        def f(a):
3911            return a.view(-1)
3912
3913        inp = torch.ones(2, 2, requires_grad=True)
3914        out = f(inp)
3915
3916        self.assertIsNotNone(out.grad_fn)
3917        self.assertExpectedInline(
3918            str(out.grad_fn.__class__), """<class 'ViewBackward0'>"""
3919        )
3920
3921    def test_output_aliases_intermediate_view_meta_replay(self):
3922        @self._compile_and_erase_bases(0, 1)
3923        def f(a):
3924            b = a.clone()
3925            return b.view(-1), b.view(-1)
3926
3927        inp = torch.ones(2, 2, requires_grad=True)
3928        out1, out2 = f(inp)
3929
3930        self.assertIsNotNone(out1.grad_fn)
3931        self.assertExpectedInline(
3932            str(out1.grad_fn.__class__), """<class 'ViewBackward0'>"""
3933        )
3934
3935        self.assertIsNotNone(out2.grad_fn)
3936        self.assertExpectedInline(
3937            str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
3938        )
3939
3940    def test_output_aliases_output_view_meta_replay(self):
3941        @self._compile_and_erase_bases(1)
3942        def f(a):
3943            b = a.add(10)
3944            return b, b.view(-1)
3945
3946        inp = torch.ones(2, 2, requires_grad=True)
3947        out1, out2 = f(inp)
3948
3949        self.assertEqual(out1.untyped_storage(), out2.untyped_storage())
3950        self.assertIsNotNone(out2.grad_fn)
3951        self.assertExpectedInline(
3952            str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
3953        )
3954
3955    @skipIfTorchDynamo()
3956    @patch("torch._dynamo.config.assume_static_by_default", False)
3957    def test_dynamic_output_aliases_input_view_meta_replay(self):
3958        # - torch.compile: using it so we can have a SymInt in the FX graph.
3959        # - Compiling with inductor, so that tensor._base isn't tracked.
3960        #
3961        # This should force the use of as_strided in the view reconstruction path.
3962        # The first 2 view-replay paths won't be taken because:
3963        #   - target_functional_tensor will be symbolic (_functionalize_is_symbolic call)
3964        #   - tensor._base will be None
3965        @torch.compile(backend="inductor")
3966        def f(a, sz):
3967            return a.view(sz), a.view(-1)
3968
3969        inp = torch.ones(2, 2, requires_grad=True)
3970        out1, out2 = f(inp, (4,))
3971
3972        self.assertIsNotNone(out1.grad_fn)
3973        self.assertExpectedInline(
3974            str(out1.grad_fn.__class__), """<class 'AsStridedBackward0'>"""
3975        )
3976
3977        self.assertIsNotNone(out2.grad_fn)
3978        self.assertExpectedInline(
3979            str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
3980        )
3981
3982
3983def extract_graph(fx_g, _, graph_cell):
3984    graph_cell[0] = fx_g
3985    return fx_g
3986
3987
3988def get_ins_outs(fx_g):
3989    ins = []
3990    outs = []
3991    for n in fx_g.graph.nodes:
3992        if n.op == "placeholder":
3993            ins.append(n)
3994        elif n.op == "output":
3995            outs = tuple(n.args[0])
3996    return ins, outs
3997
3998
3999def get_num_ins_outs(fx_g):
4000    return tuple(len(i) for i in get_ins_outs(fx_g))
4001
4002
4003def get_fw_bw_graph(
4004    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
4005):
4006    fw_graph_cell = [None]
4007    bw_graph_cell = [None]
4008    aot_function(
4009        f,
4010        fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
4011        bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
4012        partition_fn=partitioner,
4013        decompositions=default_decompositions,
4014        dynamic=dynamic,
4015    )(*inps).sum().backward()
4016    return (fw_graph_cell[0], bw_graph_cell[0])
4017
4018
4019class TestMod(torch.nn.Module):
4020    def __init__(self, fn):
4021        super().__init__()
4022        self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True))
4023        self.fn = fn
4024
4025    def forward(self, *args):
4026        return self.fn(self.p, *args)
4027
4028
4029class TestAOTExport(AOTTestCase):
4030    def test_aot_export_ban_dropout_mut_pre_dispatch(self):
4031        def fn(p, x):
4032            y = torch.ops.aten.dropout.default(x, 0.1, train=False)
4033            y.add_(1)
4034            return (y,)
4035
4036        mod = TestMod(fn)
4037        inp = torch.randn(2, 2)
4038
4039        with self.assertRaisesRegex(
4040            RuntimeError, "cannot mutate tensors with frozen storage"
4041        ):
4042            aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4043
4044        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=False)
4045        self.assertExpectedInline(
4046            str(gm.code).strip(),
4047            """\
4048def forward(self, arg0_1, arg1_1):
4049    clone = torch.ops.aten.clone.default(arg1_1);  arg1_1 = None
4050    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
4051    return (add,)""",
4052        )
4053
4054        fw_graph_cell = [None]
4055        bw_graph_cell = [None]
4056
4057        compiled_outs = aot_function(
4058            fn,
4059            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
4060            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
4061            partition_fn=default_partition,
4062            decompositions=default_decompositions,
4063            dynamic=True,
4064        )(*inp)
4065        fw_graph = fw_graph_cell[0]
4066        bw_graph = bw_graph_cell[0]
4067
4068        self.assertExpectedInline(
4069            str(fw_graph.code).strip(),
4070            """\
4071def forward(self, arg0_1, arg1_1):
4072    clone = torch.ops.aten.clone.default(arg1_1);  arg1_1 = None
4073    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
4074    return (add,)""",
4075        )
4076
4077    def test_aot_export_predispatch_func_simple(self):
4078        def fn(p, x):
4079            y = x + 2
4080            with torch.no_grad():
4081                y.add_(2)
4082            return (x * 2 + y,)
4083
4084        mod = TestMod(fn)
4085        inp = torch.randn(2, 2)
4086
4087        with torch.no_grad():
4088            gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4089        self.assertExpectedInline(
4090            str(gm.code).strip(),
4091            """\
4092def forward(self, arg0_1, arg1_1):
4093    add = torch.ops.aten.add.Tensor(arg1_1, 2)
4094    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
4095    add_1 = torch.ops.aten.add.Tensor(add, 2);  add = None
4096    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
4097    mul = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
4098    add_2 = torch.ops.aten.add.Tensor(mul, add_1);  mul = add_1 = None
4099    return (add_2,)""",
4100        )
4101
4102    def test_aot_export_predispatch_func_composite_implicit(self):
4103        def fn(p, x):
4104            with torch.enable_grad():
4105                y = x @ x
4106            y.add_(2)
4107            return (x.sum() + y.sum(),)
4108
4109        mod = TestMod(fn)
4110        inp = torch.randn(2, 2)
4111
4112        with torch.no_grad():
4113            gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4114        self.assertExpectedInline(
4115            str(gm.code).strip(),
4116            """\
4117def forward(self, arg0_1, arg1_1):
4118    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
4119    matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1)
4120    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
4121    add = torch.ops.aten.add.Tensor(matmul, 2);  matmul = None
4122    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
4123    sum_2 = torch.ops.aten.sum.default(add);  add = None
4124    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
4125    return (add_1,)""",
4126        )
4127
4128    def test_aot_export_predispatch_composite_implicit_inplace(self):
4129        def fn(x, p):
4130            return (torch.ops.aten.absolute_.default(x.clone()),)
4131
4132        mod = TestMod(fn)
4133        inp = torch.randn(2, 2)
4134
4135        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4136        self.assertExpectedInline(
4137            str(gm.code).strip(),
4138            """\
4139def forward(self, arg0_1, arg1_1):
4140    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
4141    abs_1 = torch.ops.aten.abs.default(clone);  clone = None
4142    return (abs_1,)""",
4143        )
4144
4145    def test_aot_export_predispatch_composite_implicit_linear(self):
4146        class MM(torch.nn.Module):
4147            def __init__(self) -> None:
4148                super().__init__()
4149                self.linear = torch.nn.Linear(2, 2)
4150
4151            def forward(self, x):
4152                return (self.linear(x),)
4153
4154        mod = MM()
4155        inp = torch.randn(2, 2)
4156
4157        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4158        self.assertExpectedInline(
4159            str(gm.code).strip(),
4160            """\
4161def forward(self, arg0_1, arg1_1, arg2_1):
4162    linear = torch.ops.aten.linear.default(arg2_1, arg0_1, arg1_1);  arg2_1 = arg0_1 = arg1_1 = None
4163    return (linear,)""",
4164        )
4165
4166    @unittest.expectedFailure
4167    def test_aot_export_predispatch_outdtype(self):
4168        class M(torch.nn.Module):
4169            def __init__(self, weight):
4170                super().__init__()
4171                self.weight = weight
4172
4173            def forward(self, x):
4174                y = x + 2
4175                y.add_(5)
4176                return (
4177                    out_dtype(torch.ops.aten.mm.default, torch.int32, y, self.weight),
4178                )
4179
4180        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
4181        mod = M(weight)
4182        inp = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
4183
4184        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4185        self.assertExpectedInline(
4186            str(gm.code).strip(),
4187            """\
4188def forward(self, arg0_1, arg1_1):
4189    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
4190    mm = torch.ops.aten.mm.default(arg1_1, arg1_1)
4191    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
4192    add = torch.ops.aten.add.Tensor(mm, 2);  mm = None
4193    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
4194    sum_2 = torch.ops.aten.sum.default(add);  add = None
4195    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
4196    return (add_1,)""",
4197        )
4198
4199    def test_aot_export_predispatch_func_view(self):
4200        def fn(p, x):
4201            y = x @ x
4202            y.add_(2)
4203            return (x.sum() + y.view(1, 4).sum(),)
4204
4205        mod = TestMod(fn)
4206        inp = torch.randn(2, 2)
4207
4208        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4209        self.assertExpectedInline(
4210            str(gm.code).strip(),
4211            """\
4212def forward(self, arg0_1, arg1_1):
4213    matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1)
4214    add = torch.ops.aten.add.Tensor(matmul, 2);  matmul = None
4215    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
4216    view_1 = torch.ops.aten.view.default(add, [1, 4]);  add = None
4217    sum_2 = torch.ops.aten.sum.default(view_1);  view_1 = None
4218    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
4219    return (add_1,)""",
4220        )
4221
4222    def test_aot_export_predispatch_buffer_mutation_metadata(self):
4223        class Foo(torch.nn.Module):
4224            def __init__(self) -> None:
4225                super().__init__()
4226                self.foo = torch.nn.Buffer(torch.zeros(2, 2))
4227
4228            def forward(self, x):
4229                self.foo.add_(4)
4230                return (x.sum() + self.foo.sum(),)
4231
4232        inp = torch.randn(2, 2)
4233
4234        gm, graph_sig = aot_export_module(
4235            Foo(), [inp], trace_joint=False, pre_dispatch=True
4236        )
4237        self.assertExpectedInline(
4238            str(gm.code).strip(),
4239            """\
4240def forward(self, arg0_1, arg1_1):
4241    add = torch.ops.aten.add.Tensor(arg0_1, 4);  arg0_1 = None
4242    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
4243    sum_2 = torch.ops.aten.sum.default(add)
4244    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
4245    return (add, add_1)""",
4246        )
4247        eager_mod = Foo()
4248        output_1, output_2 = gm(torch.zeros(2, 2), inp)
4249        eager_output = eager_mod(inp)
4250        self.assertTrue(torch.allclose(output_2, eager_output[0]))
4251
4252        _, output_2 = gm(output_1, inp)
4253        eager_output = eager_mod(inp)
4254        self.assertTrue(torch.allclose(output_2, eager_output[0]))
4255        self.assertTrue("foo" in graph_sig.buffers)
4256        self.assertTrue(graph_sig.inputs_to_buffers["arg0_1"] == "foo")
4257
4258    def test_aot_export_predispatch_with_autograd_op(self):
4259        def foo(p, x):
4260            with torch.enable_grad():
4261                y = x + 5
4262                y.add_(5)
4263                y.add_(7)
4264                return (x.cos() + y.sin(),)
4265
4266        inp = torch.randn(2, 2)
4267        mod = TestMod(foo)
4268
4269        with torch.no_grad():
4270            gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4271        self.assertExpectedInline(
4272            str(gm.code).strip(),
4273            """\
4274def forward(self, arg0_1, arg1_1):
4275    _set_grad_enabled = torch._C._set_grad_enabled(True);  _set_grad_enabled = None
4276    add = torch.ops.aten.add.Tensor(arg1_1, 5)
4277    add_1 = torch.ops.aten.add.Tensor(add, 5);  add = None
4278    add_2 = torch.ops.aten.add.Tensor(add_1, 7);  add_1 = None
4279    cos = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
4280    sin = torch.ops.aten.sin.default(add_2);  add_2 = None
4281    add_3 = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
4282    _set_grad_enabled_1 = torch._C._set_grad_enabled(False);  _set_grad_enabled_1 = None
4283    return (add_3,)""",
4284        )
4285
4286    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
4287    @unittest.skipIf(
4288        not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported"
4289    )
4290    def test_aot_export_predispatch_with_cond_nested(self):
4291        class M(torch.nn.Module):
4292            def __init__(self) -> None:
4293                super().__init__()
4294
4295            def forward(self, x):
4296                def true_fn(x):
4297                    y = x.sin()
4298                    y.add_(5)
4299
4300                    def true_true_fn(x):
4301                        y = x.sin()
4302                        y.add_(7)
4303                        return y.sin()
4304
4305                    def true_false_fn(x):
4306                        return x.cos()
4307
4308                    return torch.cond(
4309                        y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()]
4310                    )
4311
4312                def false_fn(x):
4313                    z = x.cos()
4314                    z.add_(6)
4315                    return z.sin()
4316
4317                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
4318                return (a + 3, a + 4)
4319
4320        inp = torch.randn(2, 2)
4321        gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
4322        self.assertExpectedInline(
4323            str(gm.code).strip(),
4324            """\
4325def forward(self, arg0_1):
4326    sum_1 = torch.ops.aten.sum.default(arg0_1)
4327    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
4328    true_graph_0 = self.true_graph_0
4329    false_graph_0 = self.false_graph_0
4330    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
4331    getitem = cond[0];  cond = None
4332    add = torch.ops.aten.add.Tensor(getitem, 3)
4333    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
4334    return (add, add_1)""",  # noqa: B950
4335        )
4336
4337        self.assertExpectedInline(
4338            str(gm.true_graph_0.code).strip(),
4339            """\
4340def forward(self, arg0_1):
4341    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
4342    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
4343    cos = torch.ops.aten.cos.default(add)
4344    sum_1 = torch.ops.aten.sum.default(cos);  cos = None
4345    gt = torch.ops.aten.gt.Scalar(sum_1, 5);  sum_1 = None
4346    cos_1 = torch.ops.aten.cos.default(add);  add = None
4347    true_graph_0 = self.true_graph_0
4348    false_graph_0 = self.false_graph_0
4349    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]);  gt = true_graph_0 = false_graph_0 = cos_1 = None
4350    getitem = cond[0];  cond = None
4351    return (getitem,)""",  # noqa: B950
4352        )
4353
4354        self.assertExpectedInline(
4355            str(gm.true_graph_0.true_graph_0.code).strip(),
4356            """\
4357def forward(self, arg0_1):
4358    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
4359    add = torch.ops.aten.add.Tensor(sin, 7);  sin = None
4360    sin_1 = torch.ops.aten.sin.default(add);  add = None
4361    return (sin_1,)""",
4362        )
4363
4364    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
4365    @unittest.skipIf(
4366        not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported"
4367    )
4368    def test_aot_export_predispatch_map_1(self):
4369        class M(torch.nn.Module):
4370            def __init__(self) -> None:
4371                super().__init__()
4372
4373            def forward(self, x, y):
4374                def true_fn(x, r):
4375                    y = x.sin()
4376                    y.add_(5)
4377                    return y.cos() + r.sum()
4378
4379                def false_fn(x, r):
4380                    z = x.cos()
4381
4382                    def f(x, y):
4383                        a = x.cos()
4384                        a.add_(5)
4385                        return a + y
4386
4387                    return (
4388                        z
4389                        + control_flow.map(f, z, r).sum()
4390                        + control_flow.map(f, z, r).sum()
4391                    )
4392
4393                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y])
4394                return (a + 3, a + 4)
4395
4396        inps = [torch.randn(2, 2), torch.ones(2)]
4397        gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
4398        self.assertExpectedInline(
4399            str(gm.code).strip(),
4400            """\
4401def forward(self, arg0_1, arg1_1):
4402    sum_1 = torch.ops.aten.sum.default(arg0_1)
4403    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
4404    true_graph_0 = self.true_graph_0
4405    false_graph_0 = self.false_graph_0
4406    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
4407    getitem = cond[0];  cond = None
4408    add = torch.ops.aten.add.Tensor(getitem, 3)
4409    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
4410    return (add, add_1)""",  # noqa: B950
4411        )
4412        self.assertExpectedInline(
4413            str(gm.true_graph_0.code).strip(),
4414            """\
4415def forward(self, arg0_1, arg1_1):
4416    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
4417    add = torch.ops.aten.add.Tensor(sin, 5);  sin = None
4418    cos = torch.ops.aten.cos.default(add);  add = None
4419    sum_1 = torch.ops.aten.sum.default(arg1_1);  arg1_1 = None
4420    add_1 = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
4421    return (add_1,)""",
4422        )
4423        self.assertExpectedInline(
4424            str(gm.false_graph_0.code).strip(),
4425            """\
4426def forward(self, arg0_1, arg1_1):
4427    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
4428    select = torch.ops.aten.select.int(cos, 0, 0);  select = None
4429    body_graph_0 = self.body_graph_0
4430    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]);  body_graph_0 = None
4431    getitem = map_impl[0];  map_impl = None
4432    sum_1 = torch.ops.aten.sum.default(getitem);  getitem = None
4433    add = torch.ops.aten.add.Tensor(cos, sum_1);  sum_1 = None
4434    select_1 = torch.ops.aten.select.int(cos, 0, 0);  select_1 = None
4435    body_graph_1 = self.body_graph_1
4436    map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]);  body_graph_1 = cos = arg1_1 = None
4437    getitem_1 = map_impl_1[0];  map_impl_1 = None
4438    sum_2 = torch.ops.aten.sum.default(getitem_1);  getitem_1 = None
4439    add_1 = torch.ops.aten.add.Tensor(add, sum_2);  add = sum_2 = None
4440    return (add_1,)""",
4441        )
4442        self.assertExpectedInline(
4443            str(gm.false_graph_0.body_graph_0.code).strip(),
4444            """\
4445def forward(self, arg0_1, arg1_1):
4446    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
4447    add = torch.ops.aten.add.Tensor(cos, 5);  cos = None
4448    add_1 = torch.ops.aten.add.Tensor(add, arg1_1);  add = arg1_1 = None
4449    return (add_1,)""",
4450        )
4451
4452    def test_aot_export_predispatch_map_2(self):
4453        class M(torch.nn.Module):
4454            def __init__(self) -> None:
4455                super().__init__()
4456
4457            def forward(self, x, y):
4458                z = x.cos()
4459
4460                def f(x, y):
4461                    a = x.cos()
4462                    a.add_(5)
4463                    return a + y
4464
4465                return (z + control_flow.map(f, z, y).sum(),)
4466
4467        inps = [torch.randn(2, 2), torch.ones(2)]
4468        gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
4469        self.assertExpectedInline(
4470            str(gm.code).strip(),
4471            """\
4472def forward(self, arg0_1, arg1_1):
4473    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
4474    body_graph_0 = self.body_graph_0
4475    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]);  body_graph_0 = arg1_1 = None
4476    getitem = map_impl[0];  map_impl = None
4477    sum_1 = torch.ops.aten.sum.default(getitem);  getitem = None
4478    add = torch.ops.aten.add.Tensor(cos, sum_1);  cos = sum_1 = None
4479    return (add,)""",
4480        )  # noqa: B950
4481        self.assertExpectedInline(
4482            str(gm.body_graph_0.code).strip(),
4483            """\
4484def forward(self, arg0_1, arg1_1):
4485    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
4486    add = torch.ops.aten.add.Tensor(cos, 5);  cos = None
4487    add_1 = torch.ops.aten.add.Tensor(add, arg1_1);  add = arg1_1 = None
4488    return [add_1]""",
4489        )
4490
4491    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
4492    @unittest.skipIf(
4493        not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported"
4494    )
4495    def test_aot_export_predispatch_with_cond(self):
4496        class M(torch.nn.Module):
4497            def __init__(self) -> None:
4498                super().__init__()
4499
4500            def forward(self, x):
4501                def true_fn(x):
4502                    y = x.sin()
4503                    z = torch.ops.aten.linear.default(y, torch.randn(2, 2))
4504                    z.add_(5)
4505                    return z.cos()
4506
4507                def false_fn(x):
4508                    z = x.cos()
4509                    z.add_(6)
4510                    return z.sin()
4511
4512                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
4513                return (a + 3, a + 4)
4514
4515        inp = torch.randn(2, 2)
4516        gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
4517        self.assertExpectedInline(
4518            str(gm.code).strip(),
4519            """\
4520def forward(self, arg0_1):
4521    sum_1 = torch.ops.aten.sum.default(arg0_1)
4522    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
4523    true_graph_0 = self.true_graph_0
4524    false_graph_0 = self.false_graph_0
4525    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
4526    getitem = cond[0];  cond = None
4527    add = torch.ops.aten.add.Tensor(getitem, 3)
4528    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
4529    return (add, add_1)""",  # noqa: B950
4530        )
4531        self.assertExpectedInline(
4532            str(gm.true_graph_0.code).strip(),
4533            """\
4534def forward(self, arg0_1):
4535    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
4536    randn = torch.ops.aten.randn.default([2, 2], device = device(type='cpu'), pin_memory = False)
4537    linear = torch.ops.aten.linear.default(sin, randn);  sin = randn = None
4538    add = torch.ops.aten.add.Tensor(linear, 5);  linear = None
4539    cos = torch.ops.aten.cos.default(add);  add = None
4540    return (cos,)""",
4541        )
4542
4543    def test_aot_export_predispatch_conv_and_bn(self):
4544        class ConvBatchnorm(torch.nn.Module):
4545            def __init__(self) -> None:
4546                super().__init__()
4547                self.conv = torch.nn.Conv2d(1, 3, 1, 1)
4548                self.bn = torch.nn.BatchNorm2d(3)
4549
4550            def forward(self, x):
4551                x = self.conv(x)
4552                x = self.bn(x)
4553                return (x,)
4554
4555        mod = ConvBatchnorm()
4556        mod.train()
4557        inp = torch.randn(1, 1, 3, 3)
4558
4559        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4560        self.assertExpectedInline(
4561            str(gm.code).strip(),
4562            """\
4563def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
4564    conv2d = torch.ops.aten.conv2d.default(arg7_1, arg0_1, arg1_1);  arg7_1 = arg0_1 = arg1_1 = None
4565    add = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
4566    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  conv2d = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None
4567    getitem = _native_batch_norm_legit_functional[0]
4568    getitem_3 = _native_batch_norm_legit_functional[3]
4569    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
4570    return (getitem_3, getitem_4, add, getitem)""",  # noqa: B950
4571        )
4572
4573    def test_aot_export_predispatch_reshape(self):
4574        class Reshape(torch.nn.Module):
4575            def forward(self, x):
4576                y = x.reshape(4, 4)
4577                return (y.sum(),)
4578
4579        mod = Reshape()
4580        inp = torch.randn(2, 8)
4581
4582        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4583        self.assertExpectedInline(
4584            str(gm.code).strip(),
4585            """\
4586def forward(self, arg0_1):
4587    view = torch.ops.aten.view.default(arg0_1, [4, 4]);  arg0_1 = None
4588    sum_1 = torch.ops.aten.sum.default(view);  view = None
4589    return (sum_1,)""",
4590        )  # noqa: B950
4591
4592    def test_aot_export_predispatch_contiguous(self):
4593        class Cont(torch.nn.Module):
4594            def forward(self, x):
4595                y = torch.ops.aten.contiguous.default(x)
4596                return (y.sum(),)
4597
4598        mod = Cont()
4599        inp = torch.randn(2, 8)
4600
4601        gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True)
4602        self.assertExpectedInline(
4603            str(gm.code).strip(),
4604            """\
4605def forward(self, arg0_1):
4606    sum_1 = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
4607    return (sum_1,)""",
4608        )  # noqa: B950
4609
4610    def test_aot_export_module_joint(self):
4611        class ConvBatchnormRelu(torch.nn.Module):
4612            def __init__(self) -> None:
4613                super().__init__()
4614                self.conv = torch.nn.Conv2d(1, 3, 1, 1)
4615                self.bn = torch.nn.BatchNorm2d(3)
4616
4617            def forward(self, x):
4618                x = self.conv(x)
4619                x = self.bn(x)
4620                user_out = torch.nn.functional.relu(x)
4621                loss = user_out.sum()
4622                return loss, user_out.detach()
4623
4624        mod = ConvBatchnormRelu()
4625        mod.train()
4626        inp = torch.randn(1, 1, 3, 3)
4627        o_ref = mod(inp)
4628        fx_g, signature = aot_export_module(
4629            mod, [inp], trace_joint=True, output_loss_index=0
4630        )
4631        # Some important characteristics of the exported graph below:
4632        # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input
4633        # 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters)
4634        for node in fx_g.graph.nodes:
4635            node.meta.pop("stack_trace", None)
4636        self.assertExpectedInline(
4637            fx_g.print_readable(print_output=False),
4638            """\
4639class <lambda>(torch.nn.Module):
4640    def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
4641        # No stacktrace found for following nodes
4642        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  arg1_1 = None
4643        add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
4644        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  arg3_1 = arg4_1 = arg5_1 = None
4645        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
4646        getitem_1: "f32[3]" = _native_batch_norm_legit_functional[1]
4647        getitem_2: "f32[3]" = _native_batch_norm_legit_functional[2]
4648        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
4649        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
4650        relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem);  getitem = None
4651        detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu);  detach = None
4652        detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu)
4653        detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1);  detach_1 = None
4654        detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2);  detach_2 = None
4655        detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3);  detach_3 = None
4656        sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
4657        detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu);  relu = None
4658        detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5);  detach_5 = None
4659        detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6);  detach_6 = None
4660        detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7);  detach_7 = None
4661        detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8);  detach_8 = None
4662        detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9);  detach_9 = None
4663        ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
4664        expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]);  ones_like = None
4665        detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4);  detach_4 = None
4666        detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11);  detach_11 = None
4667        detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12);  detach_12 = None
4668        detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13);  detach_13 = None
4669        threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0);  expand = detach_14 = None
4670        native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]);  threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
4671        getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0]
4672        getitem_6: "f32[3]" = native_batch_norm_backward[1]
4673        getitem_7: "f32[3]" = native_batch_norm_backward[2];  native_batch_norm_backward = None
4674        convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]);  getitem_5 = arg7_1 = arg0_1 = None
4675        getitem_8 = convolution_backward[0];  getitem_8 = None
4676        getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1]
4677        getitem_10: "f32[3]" = convolution_backward[2];  convolution_backward = None
4678        return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7)
4679        """,  # noqa: B950
4680        )
4681
4682        self.assertExpectedInline(
4683            str(signature.parameters),
4684            """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""",
4685        )
4686        self.assertExpectedInline(
4687            str(signature.buffers),
4688            """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""",
4689        )
4690        self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""")
4691        self.assertExpectedInline(
4692            str(signature.inputs_to_parameters),
4693            """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""",
4694        )  # noqa: B950
4695        self.assertExpectedInline(
4696            str(signature.inputs_to_buffers),
4697            """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""",
4698        )  # noqa: B950
4699        self.assertExpectedInline(
4700            str(signature.buffers_to_mutate),
4701            """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""",
4702        )  # noqa: B950
4703        self.assertExpectedInline(
4704            str(signature.backward_signature.gradients_to_parameters),
4705            """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""",
4706        )  # noqa: B950
4707        self.assertExpectedInline(
4708            str(signature.backward_signature.gradients_to_user_inputs), """{}"""
4709        )
4710        self.assertExpectedInline(
4711            str(signature.backward_signature.loss_output), """getitem_3"""
4712        )
4713
4714        # Also check the inference graph
4715        # Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs.
4716        fx_g_inference, signature_inference = aot_export_module(
4717            mod, [inp], trace_joint=False
4718        )
4719        for node in fx_g_inference.graph.nodes:
4720            node.meta.pop("stack_trace", None)
4721        self.assertExpectedInline(
4722            fx_g_inference.print_readable(print_output=False),
4723            """\
4724class <lambda>(torch.nn.Module):
4725    def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
4726        # No stacktrace found for following nodes
4727        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  arg7_1 = arg0_1 = arg1_1 = None
4728        add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
4729        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None
4730        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
4731        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
4732        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
4733        relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem);  getitem = None
4734        sum_1: "f32[]" = torch.ops.aten.sum.default(relu)
4735        detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu);  relu = None
4736        detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach);  detach = None
4737        detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1);  detach_1 = None
4738        return (getitem_3, getitem_4, add, sum_1, detach_2)
4739        """,  # noqa: B950
4740        )
4741        # Some important characteristics of the exported graph below:
4742        # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input
4743        # 9 outputs: 2 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters)
4744
4745    def test_aot_export_simplified_basic(self):
4746        def f(x, y):
4747            return x * y, y * y.detach()
4748
4749        x = torch.randn(2, requires_grad=True)
4750        y = torch.randn(2, requires_grad=True)
4751
4752        f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False)
4753        out_ref = f(x, y)
4754        # No calling convention changes necessary to invoke the traced graph
4755        out_test = f_graph_fw(x, y)
4756        self.assertEqual(out_ref, out_test)
4757
4758        # Now test the backward
4759        x = torch.randn(2, requires_grad=True)
4760        y = torch.randn(2, requires_grad=True)
4761        x2 = x.clone().detach().requires_grad_(True)
4762        y2 = y.clone().detach().requires_grad_(True)
4763        x3 = x.clone().detach().requires_grad_(True)
4764        y3 = y.clone().detach().requires_grad_(True)
4765        f_graph_joint = aot_export_joint_simple(f, [x, y], trace_joint=True)
4766        num_fw_outputs = 2
4767        fw_g, bw_g = default_partition(
4768            f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs
4769        )
4770        out_ref2 = f(x2, y2)
4771        fw_outs = fw_g(x3, y3)
4772        out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:]
4773        self.assertEqual(out_ref2, out_test2)
4774
4775        # Test running the traced backward graph with a mocked-up grad_output
4776        grad_outs = [torch.ones_like(x) for x in out_ref2]
4777        grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs)
4778        grads_test = bw_g(*activations, *grad_outs)
4779        for g_ref, g_test in zip(grads_ref, grads_test):
4780            self.assertEqual(g_ref, g_test)
4781
4782    def test_aot_export_metadata_mutation_banned(self):
4783        def fn(p, x):
4784            x.t_()
4785            return (x * 2,)
4786
4787        mod = TestMod(fn)
4788        inp = torch.randn(2, 4)
4789        with self.assertRaisesRegex(
4790            RuntimeError, "Found an input that received a metadata mutation"
4791        ):
4792            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
4793            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
4794            aot_export_module(mod, [inp], trace_joint=False)
4795
4796    def test_aot_export_forward_mutation_no_buffer_mut(self):
4797        class M(torch.nn.Module):
4798            def __init__(self) -> None:
4799                super().__init__()
4800                self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))
4801
4802            def forward(self, x):
4803                x.add_(4)
4804                return (x.cos().sum() + self.buffer1.sum(),)
4805
4806        mod = M()
4807        inp = torch.ones(6, 4)
4808        gm, sig = aot_export_module(mod, [inp], trace_joint=False)
4809        self.assertExpectedInline(
4810            str(gm.code).strip(),
4811            """\
4812def forward(self, arg0_1, arg1_1):
4813    add = torch.ops.aten.add.Tensor(arg1_1, 4);  arg1_1 = None
4814    cos = torch.ops.aten.cos.default(add)
4815    sum_1 = torch.ops.aten.sum.default(cos);  cos = None
4816    sum_2 = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
4817    add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
4818    return (add, add_1)""",
4819        )  # noqa: B950
4820        self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg1_1"})
4821
4822    def test_aot_export_forward_mutation_multiple_mut(self):
4823        class M(torch.nn.Module):
4824            def __init__(self) -> None:
4825                super().__init__()
4826                self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))
4827
4828            def forward(self, x, y):
4829                y.add_(4)
4830                self.buffer1.add_(5)
4831                return (
4832                    x.cos().sum() + y.sin().sum(),
4833                    self.buffer1.sum(),
4834                )
4835
4836        mod = M()
4837        inp = [torch.ones(6, 4), torch.zeros(6, 4)]
4838        gm, sig = aot_export_module(mod, inp, trace_joint=False)
4839        self.assertExpectedInline(
4840            str(gm.code).strip(),
4841            """\
4842def forward(self, arg0_1, arg1_1, arg2_1):
4843    add = torch.ops.aten.add.Tensor(arg2_1, 4);  arg2_1 = None
4844    add_1 = torch.ops.aten.add.Tensor(arg0_1, 5);  arg0_1 = None
4845    cos = torch.ops.aten.cos.default(arg1_1);  arg1_1 = None
4846    sum_1 = torch.ops.aten.sum.default(cos);  cos = None
4847    sin = torch.ops.aten.sin.default(add)
4848    sum_2 = torch.ops.aten.sum.default(sin);  sin = None
4849    add_2 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
4850    sum_3 = torch.ops.aten.sum.default(add_1)
4851    return (add_1, add, add_2, sum_3)""",
4852        )  # noqa: B950
4853        self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg2_1"})
4854        self.assertEqual(sig.buffers_to_mutate, {"add_1": "buffer1"})
4855
4856    def test_aot_export_input_mutation_on_input_requiring_grad_banned(self):
4857        class M(torch.nn.Module):
4858            def forward(self, x):
4859                x.add_(4)
4860                return (x,)
4861
4862        mod = M()
4863        inp = torch.randn(2, requires_grad=True)
4864        with self.assertRaisesRegex(
4865            RuntimeError,
4866            "Found a graph input that requires gradients, and received a mutation",
4867        ):
4868            aot_export_module(mod, [inp], trace_joint=False)
4869
4870    def test_aot_export_input_mutation_on_parameter_banned(self):
4871        def fn(p, x):
4872            p.mul_(2)
4873            return (p + x,)
4874
4875        mod = TestMod(fn)
4876        inp = torch.randn(2)
4877        with self.assertRaisesRegex(
4878            RuntimeError,
4879            "Found a graph input that requires gradients, and received a mutation",
4880        ):
4881            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
4882            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
4883            aot_export_module(mod, [inp], trace_joint=False)
4884
4885    def test_aot_export_synthetic_bases_banned(self):
4886        def fn(p, x, y):
4887            x.mul_(2)
4888            return (x + y,)
4889
4890        mod = TestMod(fn)
4891        inp = torch.randn(2)
4892        inp2 = inp.view(-1)
4893        with self.assertRaisesRegex(
4894            RuntimeError, "Encountered aliased inputs that are mutated"
4895        ):
4896            aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False)
4897            aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True)
4898            aot_export_module(mod, [inp, inp2], trace_joint=False)
4899
4900    def test_aot_export_input_dupes_banned(self):
4901        def fn(p, x, y):
4902            x.mul_(2)
4903            return (x + y,)
4904
4905        mod = TestMod(fn)
4906        inp = torch.randn(2)
4907        with self.assertRaisesRegex(
4908            RuntimeError, "Encountered duplicated inputs that are mutated in the graph"
4909        ):
4910            aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False)
4911            aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True)
4912            aot_export_module(mod, [inp, inp], trace_joint=False)
4913
4914    def test_aot_export_multiple_outputs_require_grad_banned(self):
4915        def fn(p, x):
4916            out = p * x
4917            return out, out.sum()
4918
4919        mod = TestMod(fn)
4920        inp = torch.randn(2)
4921        with self.assertRaisesRegex(
4922            RuntimeError,
4923            "Found an output of the forward that requires gradients, that was not",
4924        ):
4925            aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1)
4926
4927    @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
4928    @unittest.skipIf(
4929        not torch._dynamo.is_dynamo_supported(), "Cond needs dynamo to run"
4930    )
4931    def test_aot_export_with_torch_cond(self):
4932        class M(torch.nn.Module):
4933            def __init__(self) -> None:
4934                super().__init__()
4935
4936            def forward(self, x):
4937                def true_fn(x):
4938                    y = x + 4
4939                    y.add_(5)
4940                    return x.cos()
4941
4942                def false_fn(x):
4943                    y = x + 5
4944                    y.add_(6)
4945                    return x.sin()
4946
4947                a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
4948                return (a + 3, a + 4)
4949
4950        inp = torch.randn(3, 4)
4951        gm, _ = aot_export_module(M(), (inp,), trace_joint=False)
4952        self.assertExpectedInline(
4953            gm.code.strip(),
4954            """\
4955def forward(self, arg0_1):
4956    sum_1 = torch.ops.aten.sum.default(arg0_1)
4957    gt = torch.ops.aten.gt.Scalar(sum_1, 4);  sum_1 = None
4958    true_graph_0 = self.true_graph_0
4959    false_graph_0 = self.false_graph_0
4960    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
4961    getitem = cond[0];  cond = None
4962    add = torch.ops.aten.add.Tensor(getitem, 3)
4963    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
4964    return (add, add_1)""",  # noqa: B950
4965        )
4966
4967        self.assertExpectedInline(
4968            gm.true_graph_0.code.strip(),
4969            """\
4970def forward(self, arg0_1):
4971    add = torch.ops.aten.add.Tensor(arg0_1, 4)
4972    add_1 = torch.ops.aten.add.Tensor(add, 5);  add = add_1 = None
4973    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
4974    return (cos,)""",
4975        )
4976
4977        self.assertExpectedInline(
4978            gm.false_graph_0.code.strip(),
4979            """\
4980def forward(self, arg0_1):
4981    add = torch.ops.aten.add.Tensor(arg0_1, 5)
4982    add_1 = torch.ops.aten.add.Tensor(add, 6);  add = add_1 = None
4983    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
4984    return (sin,)""",
4985        )
4986
4987    def test_aot_export_simplified_pytrees_banned(self):
4988        def fn(inps):
4989            return (inps[0] + inps[1],)
4990
4991        inp1 = torch.randn(2)
4992        inp2 = torch.randn(2)
4993        inps = [inp1, inp2]
4994        with self.assertRaisesRegex(
4995            RuntimeError,
4996            "aot_export_joint_simple requires individual inputs not to be pytrees",
4997        ):
4998            aot_export_joint_simple(fn, [inps], trace_joint=False)
4999            aot_export_joint_simple(fn, [inps], trace_joint=True)
5000
5001    def test_aot_export_functionalized_rng_banned(self):
5002        def fn(p, x):
5003            return (p + x,)
5004
5005        mod = TestMod(fn)
5006        inp = torch.randn(2)
5007        with patch(
5008            "functorch.compile.config.functionalize_rng_ops", True
5009        ), self.assertRaisesRegex(
5010            RuntimeError,
5011            "Functionalized RNG is not currently supported in the aot_export",
5012        ):
5013            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
5014            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
5015            aot_export_module(mod, [inp], trace_joint=False)
5016
5017    def test_aot_export_unbacked_arg(self):
5018        class M(torch.nn.Module):
5019            def forward(self):
5020                full = torch.full((), 11)
5021                i0 = full.item()
5022                return (torch.full((i0,), 0),)
5023
5024        gm, _ = aot_export_module(
5025            mod=M(), args=(), trace_joint=False, dynamic_shapes=True
5026        )
5027        self.assertExpectedInline(
5028            gm.code.strip(),
5029            """\
5030def forward(self):
5031    full = torch.ops.aten.full.default([], 11, device = device(type='cpu'), pin_memory = False)
5032    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(full);  full = None
5033    full_1 = torch.ops.aten.full.default([_local_scalar_dense], 0, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
5034    return (full_1,)""",  # noqa: B950
5035        )
5036
5037
5038class TestPartitioning(AOTTestCase):
5039    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
5040    def test_recompute_partitioning(self):
5041        def fn(a, b):
5042            return torch.sin(torch.sin(a)) + b
5043
5044        # Reference calculation
5045        ref_a = torch.rand(10, 10, requires_grad=True)
5046        ref_b = torch.rand(10, 10, requires_grad=True)
5047        ref = fn(ref_a, ref_b)
5048        ref.sum().backward()
5049
5050        # Compiled function calculation
5051        res_a = ref_a.clone().detach().requires_grad_(True)
5052        res_b = ref_b.clone().detach().requires_grad_(True)
5053
5054        def compile_fn(x, _):
5055            return x
5056
5057        compiled_fn = compiled_function(
5058            fn, compile_fn, compile_fn, min_cut_rematerialization_partition
5059        )
5060        res = compiled_fn(res_a, res_b)
5061        res.sum().backward()
5062        assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
5063        assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3)
5064        assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3)
5065
5066    def test_meta_tensor_inplace_op(self):
5067        # Following module results in inplace ops while tracing. The test checks
5068        # that the meta tensor information is stored for inplace ops.
5069        class MockModule(torch.nn.Module):
5070            def __init__(self) -> None:
5071                super().__init__()
5072                self.weight = torch.nn.Parameter(
5073                    torch.randn(3072, 768, requires_grad=True)
5074                )
5075                self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True))
5076
5077            def forward(self, add_4):
5078                linear_4 = torch.nn.functional.linear(
5079                    add_4, self.weight, bias=self.bias
5080                )
5081                gelu = torch.nn.functional.gelu(linear_4)
5082                return gelu
5083
5084        def check_meta_tensor(fx_g, _):
5085            for node in fx_g.graph.nodes:
5086                if node.op != "output":
5087                    assert "tensor_meta" in node.meta
5088            return fx_g
5089
5090        inp0 = torch.randn(16, 128, 768, requires_grad=True)
5091        inputs = [
5092            inp0,
5093        ]
5094        mod = MockModule().to(device="cpu")
5095        aot_mod = aot_module(mod, fw_compiler=check_meta_tensor)
5096        aot_mod(*inputs)
5097
5098    def test_default_partitioner_getitem(self):
5099        mod = nn.LayerNorm([10])
5100
5101        def f(x, mod_weight, mod_bias):
5102            return torch.nn.functional.layer_norm(
5103                x, [10], mod_weight, mod_bias, eps=1e-6
5104            )
5105
5106        fw_graph, bw_graph = get_fw_bw_graph(
5107            f,
5108            [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias],
5109            partitioner=default_partition,
5110        )
5111        self.assertEqual(get_num_ins_outs(fw_graph), (3, 6))
5112        self.assertEqual(get_num_ins_outs(bw_graph), (6, 3))
5113
5114    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
5115    def test_min_cut_partitioner_save_shape(self):
5116        def f(x):
5117            s = x.sum(dim=1)
5118            return s
5119
5120        inp = [torch.ones([10, 10], requires_grad=True)]
5121        fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True)
5122        _, fw_output = get_ins_outs(fw_graph)
5123        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
5124        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
5125        self.assertEqual(str(fw_output[0]), "sum_1")
5126        # make sure we don't do the suboptimal thing of saving the bigger primals input to sum,
5127        # rather than saving the sizes of the primals input for use in backward expand
5128        self.assertEqual(str(fw_output[1]), "sym_size_int")
5129        self.assertEqual(str(fw_output[2]), "sym_size_int_1")
5130
5131        inp = [
5132            torch.randn(10, requires_grad=True),
5133            torch.randn((3, 10), requires_grad=True),
5134            torch.randn((2, 10), requires_grad=True),
5135        ]
5136
5137        def f(a, b, c):
5138            # tried to test what happens if we save a size tuple in the graph;
5139            # turns out we never will due to how we trace, but this is probably
5140            # still a good test case for various size manipulations
5141            sb = torch.ops.aten.sym_size(b)
5142            sc = c.size()
5143            x = sb[0] + sc[0]
5144            a_sz = (x, a.size(0))
5145            return torch.cat([a.expand(a_sz), b, c])
5146
5147        fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True)
5148        self.assertEqual(get_num_ins_outs(fw_graph), (3, 4))
5149        self.assertEqual(get_num_ins_outs(bw_graph), (4, 3))
5150        _, outs = get_ins_outs(fw_graph)
5151        self.assertTrue(all(is_sym_node(n) for n in outs[1:]))
5152
5153    def test_default_partitioner_output_tensor_shape_tensor(self):
5154        inp = [
5155            torch.randn(10, requires_grad=True),
5156            torch.randn((3, 10), requires_grad=True),
5157            torch.randn((2, 10), requires_grad=True),
5158            torch.randn((10, 1), requires_grad=True),
5159        ]
5160
5161        def f(a, b, c, d):
5162            # Try to force symints intermixed with outputs in the function's returns
5163            sb = b.size()
5164            sc = c.size()
5165            x = sb[0] + sc[0]
5166            a_sz = (x, a.size(0))
5167            cat = torch.cat([a.expand(a_sz), b, c])
5168            mm = torch.mm(cat, d)
5169            mm2 = torch.mm(
5170                mm, a.view(mm.size(1), a.size(0))
5171            )  # this saves 4 new ints for backward. why?
5172            # and what do i have to do to make it save a tensor for backward?
5173            return cat, sb, c, mm2
5174
5175        fw_graph_cell = [None]
5176        bw_graph_cell = [None]
5177        compiled_outs = aot_function(
5178            f,
5179            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
5180            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
5181            partition_fn=default_partition,
5182            decompositions=default_decompositions,
5183            dynamic=True,
5184        )(*inp)
5185        fw_graph = fw_graph_cell[0]
5186        (compiled_outs[0].sum() + compiled_outs[2].sum()).backward()
5187        bw_graph = bw_graph_cell[0]
5188
5189        # in the fwd graph, 13 outs because:
5190        # - 5 original outputs (sb is a tuple, gets expanded to 2 symints)
5191        # - 8 saved outputs for backward: 5 tensors, 3 symints
5192        self.assertEqual(get_num_ins_outs(fw_graph), (4, 13))
5193        # in the bwd graph, 10 inputs (grad outs) because:
5194        # - The fwd graph had 13 outputs
5195        # - 1 was a view of an input, which gets regenerated outside of the graph
5196        #   and doesn't participate in the backward
5197        # - 2 user outs were symints (b.size()), which don't get tangents in the backward
5198        self.assertEqual(get_num_ins_outs(bw_graph), (10, 4))
5199        _, fw_graph_out_nodes = get_ins_outs(fw_graph)
5200        self.assertEqual(
5201            # fw outputs include b.size() which expands to 2 symints,
5202            #
5203            # TODO(whc)- are the saved-tensors/saved-symints correct here?
5204            # i just made the test pass based on what default partition did
5205            # Of the 5 original forward outputs, the 4th (c) is an input,
5206            # which won't show up in the compiled forward graph
5207            [False, True, True, False, False] + [False] * 4 + [True] * 4,
5208            [is_sym_node(n) for n in fw_graph_out_nodes],
5209        )
5210
5211        real_outs = f(*inp)
5212        self.assertEqual(compiled_outs, real_outs)
5213        self.assertTrue(isinstance(real_outs[1], torch.Size))
5214
5215        # TODO(whc) we should learn to return torch.Sizes
5216        self.assertFalse(isinstance(compiled_outs[1], torch.Size))
5217
5218    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
5219    def test_min_cut_partitioner_output_tensor_shape_tensor(self):
5220        inp = [
5221            torch.randn(10, requires_grad=True),
5222            torch.randn((3, 10), requires_grad=True),
5223            torch.randn((2, 10), requires_grad=True),
5224            torch.randn((10, 1), requires_grad=True),
5225        ]
5226
5227        def f(a, b, c, d):
5228            # Try to force symints intermixed with outputs in the function's returns
5229            sb = b.size()
5230            sc = c.size()
5231            x = sb[0] + sc[0]
5232            a_sz = (x, a.size(0))
5233            cat = torch.cat([a.expand(a_sz), b, c])
5234            mm = torch.mm(cat, d)
5235            mm2 = torch.mm(
5236                mm, a.view(mm.size(1), a.size(0))
5237            )  # this saves 4 new ints for backward. why?
5238            # and what do i have to do to make it save a tensor for backward?
5239            return cat, sb, c, mm2
5240
5241        fw_graph_cell = [None]
5242        bw_graph_cell = [None]
5243        compiled_outs = aot_function(
5244            f,
5245            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
5246            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
5247            partition_fn=min_cut_rematerialization_partition,
5248            decompositions=default_decompositions,
5249            dynamic=True,
5250        )(*inp)
5251        fw_graph = fw_graph_cell[0]
5252        (compiled_outs[0].sum() + compiled_outs[2].sum()).backward()
5253        bw_graph = bw_graph_cell[0]
5254
5255        self.assertEqual(get_num_ins_outs(fw_graph), (4, 12))
5256        self.assertEqual(get_num_ins_outs(bw_graph), (9, 4))
5257        _, fw_graph_out_nodes = get_ins_outs(fw_graph)
5258        self.assertEqual(
5259            # fw outputs include b.size() which expands to 2 symints,
5260            # then 4 tensors (transposes of matricies used for mm) are saved
5261            # finally 3 symints are saved
5262            [False, True, True, False, False] + [False] * 4 + [True] * 3,
5263            [is_sym_node(n) for n in fw_graph_out_nodes],
5264        )
5265
5266        real_outs = f(*inp)
5267        self.assertEqual(compiled_outs, real_outs)
5268        self.assertTrue(isinstance(real_outs[1], torch.Size))
5269
5270        # TODO(whc) we should learn to return torch.Sizes
5271        self.assertFalse(isinstance(compiled_outs[1], torch.Size))
5272
5273    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
5274    def test_min_cut_partitioner(self):
5275        def f(x):
5276            return x.cos().cos().cos()
5277
5278        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)])
5279        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
5280        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
5281
5282        def f(a, b, c, d):
5283            x = a + b + c + d
5284            return x.cos().cos()
5285
5286        fw_graph, bw_graph = get_fw_bw_graph(
5287            f, [torch.randn(3, requires_grad=True) for _ in range(4)]
5288        )
5289        self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
5290        self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))
5291
5292    def test_contiguous(self):
5293        # The test simulates the condition where transpose followed by view
5294        # happens in the backward pass.
5295        # https://discuss.pytorch.org/t/error-on-transpose-and-view/434
5296        def f(x):
5297            return x.view(2, 3).t()
5298
5299        inp = torch.randn(6, requires_grad=True)
5300        out = aot_function(f, nop)(inp)
5301        torch.autograd.grad(out, inp, torch.randn(3, 2))
5302
5303    def test_preserve_random(self):
5304        def fn(x):
5305            return torch.nn.functional.dropout(x, 0.5) + x
5306
5307        x = torch.randn(4)
5308
5309        torch.manual_seed(0)
5310        ref = fn(x)
5311
5312        torch.manual_seed(0)
5313        aot_fn = aot_function(fn, nop)
5314        res = aot_fn(x)
5315
5316        assert torch.allclose(ref, res)
5317
5318    # https://github.com/pytorch/pytorch/issues/110666
5319    def test_generate_gives_inference_graph(self):
5320        # We expect this to give an inference graph
5321        def generate(x):
5322            with torch.no_grad():
5323                return torch.mul(x, x)
5324
5325        inference_graph_cell = [None]
5326        inference_compiler = make_boxed_compiler(
5327            partial(extract_graph, graph_cell=inference_graph_cell)
5328        )
5329        aot_fn = aot_function(generate, nop, inference_compiler=inference_compiler)
5330        # Even though x requires grad, we should still get an inference graph
5331        x = torch.randn(4, requires_grad=True)
5332        res = aot_fn(x)
5333        self.assertTrue(inference_graph_cell[0] is not None)
5334
5335    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
5336    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
5337    def test_autocast(self):
5338        mod = torchvision.models.resnet18().cuda()
5339        mod.train()
5340
5341        x = torch.randn(16, 3, 32, 32, device="cuda")
5342        aot_mod = memory_efficient_fusion(mod)
5343
5344        # Ensure that AOT Autograd works with AMP
5345        with torch.cuda.amp.autocast(True):
5346            res = aot_mod(x)
5347        res.sum().backward()
5348
5349
5350class TestAOTDispatch(AOTTestCase):
5351    # Tests to add cases for (non-exhaustive list, mostly for my notes):
5352    # - subclass / mode introduced in the middle of the compiled fn
5353    # - various input mutation / intermediate base tests
5354    # - input mutation that changes a tensor into a subclass
5355    # - metadata mutation? (TBD)
5356    # - guard tests (fw guards *and* bw guards)
5357    # - subclass test involving _indices_of_inps_to_detach
5358    def test_aot_dispatch_simple(self):
5359        # a is a subclass, b is not
5360        def f(a, b):
5361            aa = torch.mul(a, 6)
5362            bb = torch.div(b, 2)
5363            return aa + bb
5364
5365        a1_ref = torch.ones(3, 3, requires_grad=True)
5366        a2_ref = torch.ones(3, 3, requires_grad=True)
5367        a_ref = TwoTensor(a1_ref, a2_ref)
5368        b_ref = torch.ones(3, 3, requires_grad=True)
5369
5370        a1_test = a1_ref.clone().detach().requires_grad_(True)
5371        a2_test = a2_ref.clone().detach().requires_grad_(True)
5372        a_test = TwoTensor(a1_test, a2_test)
5373        b_test = b_ref.clone().detach().requires_grad_(True)
5374
5375        fw_graph_cell = [None]
5376        bw_graph_cell = [None]
5377
5378        compiled_f = aot_function(
5379            f,
5380            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
5381            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
5382            partition_fn=min_cut_rematerialization_partition,
5383        )
5384        out_ref = f(a_ref, b_ref)
5385        out_test = compiled_f(a_test, b_test)
5386
5387        # Output is a TwoTensor (check both inner tensors)
5388        self.assertEqual(out_ref.a, out_test.a)
5389        self.assertEqual(out_ref.b, out_test.b)
5390
5391        out_ref.sum().backward()
5392        out_test.sum().backward()
5393        # Both grad_inputs are TwoTensor
5394        self.assertEqual(a_ref.grad.a, a_test.grad.a)
5395        self.assertEqual(a_ref.grad.b, a_test.grad.b)
5396        self.assertEqual(b_ref.grad.a, b_test.grad.a)
5397        self.assertEqual(b_ref.grad.b, b_test.grad.b)
5398
5399        # Important pieces of the graph:
5400        # - mul() and div() show up twice, because we called them on a TwoTensor
5401        # - add() shows up once, because we called it on a plain Tensor
5402        # - The user forward() fn returns 1 output (the result of add),
5403        #   while the graph itself returns two outputs (add, add_1)
5404        # - add, add_1 correspond to the two inner dense tensors that will be wrapped
5405        # - into a single TwoTensor output.
5406        self.assertExpectedInline(
5407            fw_graph_cell[0].code.strip(),
5408            """\
5409def forward(self, primals_1, primals_2, primals_3):
5410    mul = torch.ops.aten.mul.Tensor(primals_1, 6);  primals_1 = None
5411    mul_1 = torch.ops.aten.mul.Tensor(primals_2, 6);  primals_2 = None
5412    div = torch.ops.aten.div.Tensor(primals_3, 2);  primals_3 = None
5413    add = torch.ops.aten.add.Tensor(mul, div);  mul = None
5414    add_1 = torch.ops.aten.add.Tensor(mul_1, div);  mul_1 = div = None
5415    return (add, add_1)""",
5416        )
5417
5418        # Important pieces of the graph:
5419        # - 4 total dense outputs.
5420        #   This corresponds to the fact that each user fwd inpt (a, b)
5421        #   will get a gradient that is a TwoTensor subclass,
5422        #   so (mul_2, mul_3) will be wrapped into a.grad
5423        #   and (div_1, div_2) will be wrapped into b.grad
5424        # - 4 total dense outputs,
5425        self.assertExpectedInline(
5426            bw_graph_cell[0].code.strip(),
5427            """\
5428def forward(self, tangents_1, tangents_2):
5429    div_1 = torch.ops.aten.div.Tensor(tangents_1, 2)
5430    div_2 = torch.ops.aten.div.Tensor(tangents_2, 2)
5431    mul_2 = torch.ops.aten.mul.Tensor(tangents_1, 6);  tangents_1 = None
5432    mul_3 = torch.ops.aten.mul.Tensor(tangents_2, 6);  tangents_2 = None
5433    return (mul_2, mul_3, div_1, div_2)""",
5434        )
5435
5436    def test_aot_dispatch_inference(self):
5437        # a is a subclass, b is not
5438        def f(a, b):
5439            aa = torch.mul(a, 6)
5440            bb = torch.div(b, 2)
5441            return aa + bb
5442
5443        a1_ref = torch.ones(3, 3)
5444        a2_ref = torch.ones(3, 3)
5445        a_ref = TwoTensor(a1_ref, a2_ref)
5446        b_ref = torch.ones(3, 3)
5447
5448        a1_test = a1_ref.clone()
5449        a2_test = a2_ref.clone()
5450        a_test = TwoTensor(a1_test, a2_test)
5451        b_test = b_ref.clone()
5452
5453        compiled_f = aot_function(
5454            f,
5455            fw_compiler=nop,
5456            bw_compiler=nop,
5457            partition_fn=min_cut_rematerialization_partition,
5458        )
5459        out_ref = f(a_ref, b_ref)
5460        out_test = compiled_f(a_test, b_test)
5461
5462        # Output is a TwoTensor (check both inner tensors)
5463        self.assertEqual(out_ref.a, out_test.a)
5464        self.assertEqual(out_ref.b, out_test.b)
5465
5466    def test_aot_dispatch_incorrect_backward(self):
5467        # a is a subclass, b is not
5468        def f(a, b):
5469            aa = torch.mul(a, 2)
5470            bb = torch.add(b, 3)
5471            out_subclass = torch.div(aa, bb)
5472            out_reg = torch.add(b, b)
5473            # When creating the joint, we assume that the second grad_out
5474            # is not a subclass.
5475            # In the below test case though, we end up being wrong.
5476            # This would require re-tracing and recompiling the backward.
5477            return out_subclass, out_reg
5478
5479        a1_ref = torch.ones(3, 3, requires_grad=True)
5480        a2_ref = torch.ones(3, 3, requires_grad=True)
5481        a_ref = TwoTensor(a1_ref, a2_ref)
5482        b_ref = torch.ones(3, 3, requires_grad=True)
5483
5484        a1_test = a1_ref.clone().detach().requires_grad_(True)
5485        a2_test = a2_ref.clone().detach().requires_grad_(True)
5486        a_test = TwoTensor(a1_test, a2_test)
5487        b_test = b_ref.clone().detach().requires_grad_(True)
5488
5489        compiled_f = aot_function(
5490            f,
5491            fw_compiler=nop,
5492            bw_compiler=nop,
5493            partition_fn=min_cut_rematerialization_partition,
5494        )
5495        out_ref = f(a_ref, b_ref)
5496        out_test = compiled_f(a_test, b_test)
5497        # First out is a TwoTensor, second is an ordinary tensor
5498        self.assertEqual(out_ref[0].a, out_test[0].a)
5499        self.assertEqual(out_ref[0].b, out_test[0].b)
5500        self.assertEqual(out_ref[1], out_test[1])
5501
5502        # We compiled our graph assuming type(grad_out[1]) == torch.Tensor,
5503        # but we were wrong: in the below tests, it is a subclass.
5504        # This will eventually require a repartition + recompile
5505        with self.assertRaisesRegex(
5506            AssertionError,
5507            "incorrectly attempted to compile the backward with incorrect subclass metadata",
5508        ):
5509            (out_test[0] + out_test[1]).sum().backward()
5510
5511    def test_aot_dispatch_output_alias(self):
5512        # a is a tensor, b is a TwoTensor
5513        def f(a, b):
5514            return b.view(b.shape), a * b
5515
5516        b1_ref = torch.ones(3, 3, requires_grad=True)
5517        b2_ref = torch.ones(3, 3, requires_grad=True)
5518        b_ref = TwoTensor(b1_ref, b2_ref)
5519        a_ref = torch.ones(3, 3, requires_grad=True)
5520
5521        b1_test = b1_ref.clone().detach().requires_grad_(True)
5522        b2_test = b2_ref.clone().detach().requires_grad_(True)
5523        b_test = TwoTensor(b1_test, b2_test)
5524        a_test = a_ref.clone().detach().requires_grad_(True)
5525
5526        compiled_f = aot_function(
5527            f,
5528            fw_compiler=nop,
5529            bw_compiler=nop,
5530            partition_fn=min_cut_rematerialization_partition,
5531        )
5532        out_ref1, out_ref2 = f(a_ref, b_ref)
5533        out_test1, out_test2 = compiled_f(a_test, b_test)
5534        self.assertEqual(out_ref1, out_test1)
5535        self.assertEqual(out_ref2.a, out_test2.a)
5536        self.assertEqual(out_ref2.b, out_test2.b)
5537
5538        (out_ref1 + out_ref2).sum().backward()
5539        (out_test1 + out_test2).sum().backward()
5540        # Both grad_inputs are TwoTensor
5541        self.assertEqual(a_ref.grad.a, a_test.grad.a)
5542        self.assertEqual(a_ref.grad.b, a_test.grad.b)
5543        self.assertEqual(b_ref.grad.a, b_test.grad.a)
5544        self.assertEqual(b_ref.grad.b, b_test.grad.b)
5545
5546    def test_aot_dispatch_input_mutation(self):
5547        def f(a, b):
5548            a.mul_(2)
5549            b.mul_(3)
5550            return a + b
5551
5552        b1_ref = torch.ones(3, 3, requires_grad=True)
5553        b2_ref = torch.ones(3, 3, requires_grad=True)
5554        b_ref_base = TwoTensor(b1_ref, b2_ref)
5555        a_ref_base = torch.ones(3, 3, requires_grad=True)
5556        b_ref = b_ref_base + 1
5557        a_ref = a_ref_base + 1
5558
5559        b1_test = b1_ref.clone().detach().requires_grad_(True)
5560        b2_test = b2_ref.clone().detach().requires_grad_(True)
5561        b_test_base = TwoTensor(b1_test, b2_test)
5562        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
5563        b_test = b_test_base + 1
5564        a_test = a_test_base + 1
5565
5566        compiled_f = aot_function(
5567            f,
5568            fw_compiler=nop,
5569            bw_compiler=nop,
5570            partition_fn=min_cut_rematerialization_partition,
5571        )
5572        out_ref = f(a_ref, b_ref)
5573        out_test = compiled_f(a_test, b_test)
5574        self.assertEqual(out_ref.a, out_test.a)
5575        self.assertEqual(out_ref.b, out_test.b)
5576
5577        # confirm input mutations worked
5578        self.assertEqual(a_test, a_ref)
5579        self.assertEqual(b_test.a, b_ref.a)
5580        self.assertEqual(b_test.b, b_ref.b)
5581
5582        # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile teh backward.
5583        (b_ref * out_ref).sum().backward()
5584        (b_test * out_test).sum().backward()
5585        # Both grad_inputs are TwoTensor
5586        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
5587        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
5588        self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
5589        self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)
5590
5591    # NB: Metadata mutation for subclasses is currently broken and disabled
5592    # See https://github.com/pytorch/pytorch/issues/114975
5593    @unittest.expectedFailure
5594    def test_aot_dispatch_input_metadata_mutation(self):
5595        def f(a, b):
5596            a.t_()
5597            b.unsqueeze_(0)
5598            return a + b
5599
5600        b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
5601        b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
5602        b_ref_base = TwoTensor(b1_ref, b2_ref)
5603        a_ref_base = (
5604            torch.arange(9, dtype=torch.float32)
5605            .reshape(3, 3)
5606            .detach()
5607            .requires_grad_(True)
5608        )
5609        b_ref = b_ref_base + 1
5610        a_ref = a_ref_base + 1
5611
5612        b1_test = b1_ref.clone().detach().requires_grad_(True)
5613        b2_test = b2_ref.clone().detach().requires_grad_(True)
5614        b_test_base = TwoTensor(b1_test, b2_test)
5615        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
5616        b_test = b_test_base + 1
5617        a_test = a_test_base + 1
5618
5619        compiled_f = aot_function(
5620            f,
5621            fw_compiler=nop,
5622            bw_compiler=nop,
5623            partition_fn=min_cut_rematerialization_partition,
5624        )
5625        out_ref = f(a_ref, b_ref)
5626        out_test = compiled_f(a_test, b_test)
5627        self.assertEqual(out_ref.a, out_test.a)
5628        self.assertEqual(out_ref.b, out_test.b)
5629
5630        # confirm input mutations worked
5631        self.assertEqual(a_test, a_ref)
5632        self.assertEqual(b_test.a, b_ref.a)
5633        self.assertEqual(b_test.b, b_ref.b)
5634
5635        # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward.
5636        (b_ref * out_ref).sum().backward()
5637        (b_test * out_test).sum().backward()
5638        # Both grad_inputs are TwoTensor
5639        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
5640        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
5641        self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
5642        self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)
5643
5644    # NB: Metadata mutation for subclasses is currently broken and disabled
5645    # See https://github.com/pytorch/pytorch/issues/114975
5646    @unittest.expectedFailure
5647    def test_aot_dispatch_input_data_and_metadata_mutation(self):
5648        def f(a, b):
5649            a.t_()
5650            b.unsqueeze_(0)
5651            a.mul_(2)
5652            b.mul_(3)
5653            return a + b
5654
5655        b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
5656        b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
5657        b_ref_base = TwoTensor(b1_ref, b2_ref)
5658        a_ref_base = (
5659            torch.arange(9, dtype=torch.float32)
5660            .reshape(3, 3)
5661            .detach()
5662            .requires_grad_(True)
5663        )
5664        b_ref = b_ref_base + 1
5665        a_ref = a_ref_base + 1
5666
5667        b1_test = b1_ref.clone().detach().requires_grad_(True)
5668        b2_test = b2_ref.clone().detach().requires_grad_(True)
5669        b_test_base = TwoTensor(b1_test, b2_test)
5670        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
5671        b_test = b_test_base + 1
5672        a_test = a_test_base + 1
5673
5674        compiled_f = aot_function(
5675            f,
5676            fw_compiler=nop,
5677            bw_compiler=nop,
5678            partition_fn=min_cut_rematerialization_partition,
5679        )
5680        out_ref = f(a_ref, b_ref)
5681        out_test = compiled_f(a_test, b_test)
5682        self.assertEqual(out_ref.a, out_test.a)
5683        self.assertEqual(out_ref.b, out_test.b)
5684
5685        # confirm input mutations worked
5686        self.assertEqual(a_test, a_ref)
5687        self.assertEqual(b_test.a, b_ref.a)
5688        self.assertEqual(b_test.b, b_ref.b)
5689
5690        # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward.
5691        (b_ref * out_ref).sum().backward()
5692        (b_test * out_test).sum().backward()
5693        # Both grad_inputs are TwoTensor
5694        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
5695        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
5696        self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
5697        self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)
5698
5699    def test_aot_dispatch_input_mutation_and_output_alias(self):
5700        def f(a, b):
5701            a.mul_(2)
5702            b.mul_(3)
5703            return b.view(b.shape), a + b
5704
5705        b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
5706        b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
5707        b_ref_base = TwoTensor(b1_ref, b2_ref)
5708        a_ref_base = (
5709            torch.arange(9, dtype=torch.float32)
5710            .reshape(3, 3)
5711            .detach()
5712            .requires_grad_(True)
5713        )
5714        b_ref = b_ref_base + 1
5715        a_ref = a_ref_base + 1
5716
5717        b1_test = b1_ref.clone().detach().requires_grad_(True)
5718        b2_test = b2_ref.clone().detach().requires_grad_(True)
5719        b_test_base = TwoTensor(b1_test, b2_test)
5720        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
5721        b_test = b_test_base + 1
5722        a_test = a_test_base + 1
5723
5724        compiled_f = aot_function(
5725            f,
5726            fw_compiler=nop,
5727            bw_compiler=nop,
5728            partition_fn=min_cut_rematerialization_partition,
5729        )
5730        out_ref1, out_ref2 = f(a_ref, b_ref)
5731        out_test1, out_test2 = compiled_f(a_test, b_test)
5732        self.assertEqual(out_ref1.a, out_test1.a)
5733        self.assertEqual(out_ref1.b, out_test1.b)
5734        self.assertEqual(out_ref2.a, out_test2.a)
5735        self.assertEqual(out_ref2.b, out_test2.b)
5736
5737        # confirm input mutations worked
5738        self.assertEqual(a_test, a_ref)
5739        self.assertEqual(b_test.a, b_ref.a)
5740        self.assertEqual(b_test.b, b_ref.b)
5741
5742        (out_ref1 * out_ref2).sum().backward()
5743        (out_test1 * out_test2).sum().backward()
5744        # Both grad_inputs are TwoTensors
5745        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
5746        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
5747
5748    def test_aot_dispatch_output_requires_grad_in_no_grad(self):
5749        def fn(x):
5750            out1 = x.sin()
5751            with torch.enable_grad():
5752                out2 = x.cos()
5753            return out1, out2
5754
5755        inp_fns = [
5756            lambda: torch.ones(10, requires_grad=True),
5757            lambda: torch.ones(10, requires_grad=False),
5758        ]
5759
5760        compiled_f = aot_function(fn, nop)
5761        for inp_fn in inp_fns:
5762            with torch.no_grad():
5763                ref_x = inp_fn()
5764                ref_out = fn(ref_x)
5765                x = inp_fn()
5766                out = compiled_f(x)
5767                for r, o in zip(ref_out, out):
5768                    self.assertEqual(r.requires_grad, o.requires_grad)
5769            if ref_x.requires_grad:
5770                with torch.enable_grad():
5771                    (ref_out[0] + ref_out[1]).sum().backward()
5772                    (out[0] + out[1]).sum().backward()
5773                    self.assertEqual(ref_x.grad, x.grad)
5774                    assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3)
5775
5776    def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
5777        # view-type ops preserve requires_grad even in no_grad.
5778        def fn(x):
5779            return x.view(-1), x.sin()
5780
5781        inference_graph_cell = [None]
5782        inference_compiler = make_boxed_compiler(
5783            partial(extract_graph, graph_cell=inference_graph_cell)
5784        )
5785        compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler)
5786
5787        inp_x0 = torch.ones(2, 3, requires_grad=True)
5788        # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad
5789        ref_x0 = inp_x0.clone()
5790        x0 = inp_x0.clone()
5791        with torch.no_grad():
5792            ref_out1, ref_out2 = fn(ref_x0)
5793
5794            out1, out2 = compiled_fn(x0)
5795            # Assert that we executed inference graph
5796            self.assertTrue(inference_graph_cell[0] is not None)
5797
5798            self.assertEqual(ref_out1.requires_grad, out1.requires_grad)
5799            self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
5800
5801
5802class TestAOTModuleSimplified(AOTTestCase):
5803    def test_aot_module_simplified(self):
5804        class MockModule(torch.nn.Module):
5805            def __init__(self) -> None:
5806                super().__init__()
5807                self.linear = torch.nn.Linear(20, 30)
5808
5809            def forward(self, x, y):
5810                return (self.linear(x) + y,)
5811
5812        mod = MockModule()
5813        mod.zero_grad()
5814
5815        x = torch.randn(128, 20, requires_grad=True)
5816        y = torch.randn(128, 30, requires_grad=True)
5817        inputs = [x, y]
5818        cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]
5819
5820        ref = mod(*inputs)
5821        ref[0].sum().backward()
5822
5823        compiled_f = aot_module_simplified(mod, cloned_inputs, nop)
5824        mod.zero_grad()
5825        res = compiled_f(*cloned_inputs)
5826        res[0].sum().backward()
5827
5828        assert torch.allclose(ref[0], res[0])
5829        assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
5830        assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
5831
5832    def test_aot_module_simplified_dynamic(self):
5833        class MockModule(torch.nn.Module):
5834            def __init__(self) -> None:
5835                super().__init__()
5836                self.linear = torch.nn.Linear(20, 30)
5837
5838            def forward(self, x, y):
5839                return (self.linear(x) + y,)
5840
5841        mod = MockModule()
5842
5843        shape_env = ShapeEnv()
5844        fake_mode = FakeTensorMode(shape_env=shape_env)
5845
5846        x = torch.randn(128, 20, requires_grad=True)
5847        y = torch.randn(128, 30, requires_grad=True)
5848
5849        inputs = [x, y]
5850        fake_inputs = [fake_mode.from_tensor(x) for x in inputs]
5851        compiled_f = aot_module_simplified(mod, fake_inputs, nop)
5852
5853        ref = mod(*inputs)
5854        ref[0].sum().backward()
5855
5856        cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]
5857        res = compiled_f(*cloned_inputs)
5858        res[0].sum().backward()
5859
5860        self.assertExpectedInline(
5861            shape_env.format_guards(),
5862            """\
5863 - Eq(s1, 20)
5864 - Eq(s2, 30)""",
5865        )
5866
5867        assert torch.allclose(ref[0], res[0])
5868        assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
5869        assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
5870
5871    # https://github.com/pytorch/pytorch/issues/105327
5872    def test_lift_fresh_copy_in_graph(self):
5873        class MyMod(torch.nn.Module):
5874            def forward(self, x):
5875                _tensor_constant0 = torch.tensor([1])
5876                lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(
5877                    _tensor_constant0
5878                )
5879                y = x.mul(lift_fresh_copy)
5880                return (y,)
5881
5882        mod = MyMod()
5883        shape_env = ShapeEnv()
5884        fake_mode = FakeTensorMode(shape_env=shape_env)
5885        x = torch.ones(4, requires_grad=True)
5886        inputs = [x]
5887        fake_inputs = [fake_mode.from_tensor(x) for x in inputs]
5888        compiled_f = aot_module_simplified(mod, fake_inputs, nop)
5889
5890        out_ref = mod(x)
5891        out_test = compiled_f(x)
5892        self.assertEqual(out_ref[0].detach(), out_test[0].detach())
5893
5894    def test_inference_python_dispatcher(self):
5895        # Extracted from unet
5896        class MockModule(torch.nn.Module):
5897            def __init__(self) -> None:
5898                super().__init__()
5899                self.upsample = torch.nn.Upsample(
5900                    scale_factor=2, mode="bilinear", align_corners=True
5901                )
5902
5903            def forward(self, x):
5904                return (self.upsample(x),)
5905
5906        mod = MockModule()
5907        shape_env = ShapeEnv()
5908        fake_mode = FakeTensorMode(shape_env=shape_env)
5909        x = torch.randn(2, 512, 40, 59)  # NB: must not require grad
5910        inputs = [x]
5911        fake_inputs = [fake_mode.from_tensor(x) for x in inputs]
5912        compiled_f = aot_module_simplified(mod, fake_inputs, nop)
5913
5914    def test_aot_module_simplified_preserves_stack_trace(self):
5915        class MockModule(torch.nn.Module):
5916            def __init__(self) -> None:
5917                super().__init__()
5918                self.linear = torch.nn.Linear(20, 30)
5919
5920            def forward(self, x, y):
5921                z = self.linear(x)
5922                z = z + y
5923                z = z.relu()
5924                return (z,)
5925
5926        tracer = torch.fx.Tracer()
5927        tracer.record_stack_traces = True
5928        graph = tracer.trace(MockModule())
5929        mod = torch.fx.GraphModule(tracer.root, graph)
5930
5931        for node in mod.graph.nodes:
5932            if node.op == "output":
5933                continue
5934            self.assertTrue(node.stack_trace is not None)
5935            assert "test_aotdispatch.py" in node.stack_trace
5936
5937        def assert_compiler(gm: torch.fx.GraphModule, _):
5938            for node in gm.graph.nodes:
5939                if node.op == "output" or node.op == "placeholder":
5940                    continue
5941                self.assertTrue(node.stack_trace is not None)
5942                assert "test_aotdispatch.py" in node.stack_trace
5943            return gm.forward  # return a python callable
5944
5945        x = torch.randn(128, 20, requires_grad=True)
5946        y = torch.randn(128, 30, requires_grad=True)
5947        inputs = [x, y]
5948
5949        compiled_f = aot_module_simplified(
5950            mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler
5951        )
5952        res = compiled_f(*inputs)
5953        res[0].sum().backward()
5954
5955    def test_aot_module_simplified_preserves_stack_trace_from_mutation(self):
5956        class MockModule(torch.nn.Module):
5957            def __init__(self) -> None:
5958                super().__init__()
5959
5960            def forward(self, x):
5961                x_view = x[0]
5962                x_view.mul_(2)
5963                return (x + x,)
5964
5965        tracer = torch.fx.Tracer()
5966        tracer.record_stack_traces = True
5967        graph = tracer.trace(MockModule())
5968        mod = torch.fx.GraphModule(tracer.root, graph)
5969
5970        for node in mod.graph.nodes:
5971            if node.op == "output":
5972                continue
5973            self.assertTrue(node.stack_trace is not None)
5974            assert "test_aotdispatch.py" in node.stack_trace
5975
5976        def assert_compiler(gm: torch.fx.GraphModule, _):
5977            assert torch.ops.aten.copy_.default in [x.target for x in gm.graph.nodes]
5978            for node in gm.graph.nodes:
5979                if node.target == torch.ops.aten.copy_.default:
5980                    assert "stack_trace" in node.meta
5981                    assert "x_view.mul_(2)" in node.meta["stack_trace"]
5982            return gm.forward  # return a python callable
5983
5984        x = torch.randn(128, 20)
5985        inputs = [x]
5986
5987        aot_module_simplified(
5988            mod,
5989            inputs,
5990            fw_compiler=assert_compiler,
5991            bw_compiler=assert_compiler,
5992            keep_inference_input_mutations=True,
5993        )
5994
5995    def test_aot_module_simplified_fake_tensor_gm_raises(self):
5996        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
5997        real_x = torch.randn(4, requires_grad=True)
5998        fake_x = fake_mode.from_tensor(real_x)
5999        real_z = torch.randn(4)
6000        fake_z = fake_mode.from_tensor(real_z)
6001
6002        class MockModule(torch.nn.Module):
6003            def forward(self, x):
6004                # Accessing a free variable fake tensor will look like a
6005                # constant to make_fx, and result in the tensor being traced
6006                # into the graph, which is an error condition.  Make sure we
6007                # report adequately in this case.
6008                return (x + fake_z,)
6009
6010        with self.assertRaisesRegex(AssertionError, "Unexpected fake"):
6011            aot_module_simplified(MockModule(), (fake_x,), nop)
6012
6013    def test_aot_test_subclasses_with_tensor_factories(self):
6014        from torch.testing._internal.common_subclass import SubclassWithTensorFactory
6015
6016        inp = SubclassWithTensorFactory(torch.zeros(3, 5))
6017
6018        def fn(x):
6019            return 2 * x
6020
6021        ref_out = fn(inp)
6022        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
6023        self.assertEqual(ref_out, out)
6024
6025
6026# entries in here don't work and need to be fixed.
6027# Each one of these is a bug (or needs to be investigated)
6028aot_autograd_failures = {
6029    # data-dependent control flow
6030    xfail("cov"),
6031    xfail("nn.functional.gaussian_nll_loss"),
6032    xfail("tensor_split"),
6033    xfail("corrcoef"),
6034    xfail("quantile"),
6035    xfail("nanquantile"),
6036    xfail("narrow"),
6037    xfail("istft"),
6038    xfail("linalg.eig"),
6039    skip("as_strided_scatter"),
6040    skip("as_strided", "partial_views"),  # flaky
6041    # Given input size: (s0xs1x2). Calculated output size: ...
6042    skip("max_pool2d_with_indices_backward"),
6043    skip("nn.functional.nll_loss", ""),  # UBSAN failure!
6044    # Misc
6045    xfail("to_sparse"),
6046    xfail("corrcoef"),
6047    xfail("cov"),
6048    xfail("chalf"),  # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
6049    xfail("sparse.sampled_addmm"),
6050    xfail("sparse.mm", "reduce"),
6051    skip("nn.functional.binary_cross_entropy_with_logits"),  # seems to fail sometimes?
6052    skip("nn.functional.margin_ranking_loss"),  # seems flaky
6053    skip("linalg.lu_solve"),  # flaky
6054    decorate("matmul", decorator=unittest.skipIf(IS_ARM64, "flaky")),
6055    decorate("__rmatmul__", decorator=unittest.skipIf(IS_ARM64, "flaky")),
6056    # overrides atol=1e-4, rtol=1e-5 would do as well
6057    decorate(
6058        "svd_lowrank",
6059        decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}),
6060    ),
6061    decorate(
6062        "linalg.householder_product",
6063        decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
6064    ),
6065    decorate(
6066        "linalg.pinv",
6067        "singular",
6068        decorator=toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
6069    ),
6070    decorate(
6071        "nn.functional.interpolate",
6072        "bicubic",
6073        decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}),
6074    ),
6075    # conv2d sometimes nondeterministic in this config?
6076    decorate("nn.functional.conv2d", decorator=unittest.skipIf(IS_ARM64, "flaky")),
6077}
6078
6079symbolic_aot_autograd_failures = {
6080    xfail("combinations", ""),  # aten.masked_select.default
6081    xfail(
6082        "index_fill", ""
6083    ),  # Cannot call sizes() on tensor with symbolic sizes/strides
6084    xfail(
6085        "linalg.lstsq", ""
6086    ),  # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition
6087    xfail(
6088        "linalg.lstsq", "grad_oriented"
6089    ),  # aten.linalg_lstsq.default - couldn't find symbolic meta funct...
6090    xfail(
6091        "linalg.lu_solve", ""
6092    ),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco...
6093    skip(
6094        "nn.functional.batch_norm", ""
6095    ),  # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
6096    xfail(
6097        "nn.functional.binary_cross_entropy", ""
6098    ),  # aten.fill_.Scalar - couldn't find symbolic meta funct...
6099    xfail(
6100        "nn.functional.cross_entropy", ""
6101    ),  # Cannot call sizes() on tensor with symbolic sizes/strides
6102    xfail(
6103        "nn.functional.ctc_loss", ""
6104    ),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/deco...
6105    xfail(
6106        "nn.functional.fractional_max_pool3d", ""
6107    ),  # rand() received an invalid combination of arguments - g...
6108    xfail(
6109        "nn.functional.group_norm", ""
6110    ),  # Cannot call sizes() on tensor with symbolic sizes/strides
6111    xfail(
6112        "nn.functional.nll_loss", ""
6113    ),  # Cannot call sizes() on tensor with symbolic sizes/strides
6114    xfail(
6115        "_segment_reduce", "lengths"
6116    ),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
6117    xfail(
6118        "_segment_reduce", "offsets"
6119    ),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
6120    xfail("trace", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
6121    xfail(
6122        "_upsample_bilinear2d_aa"
6123    ),  # RuntimeError: isIntList() INTERNAL ASSERT FAILED  Expected IntList but got GenericList
6124    decorate(
6125        "linalg.householder_product",
6126        decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
6127    ),
6128    # many complex operators incorrect striding, metadata
6129    xfail("fft.fft", ""),
6130    xfail("fft.hfft2", ""),
6131    xfail("fft.hfft", ""),
6132    xfail("fft.hfftn", ""),
6133    xfail("fft.ifft", ""),
6134    xfail("fft.ihfft2", ""),
6135    xfail("fft.ihfft", ""),
6136    xfail("fft.ihfftn", ""),
6137    xfail("fft.irfft2", ""),
6138    xfail("fft.irfft", ""),
6139    xfail("fft.irfftn", ""),
6140    xfail("fft.rfft2", ""),
6141    xfail("fft.rfft", ""),
6142    xfail("fft.rfftn", ""),
6143    xfail("stft", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
6144}
6145
6146
6147def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
6148    if not op.supports_autograd:
6149        self.skipTest("Op does not support autograd")
6150
6151    # aot_autograd_check is able to check data specialization by
6152    # randomizing the inputs. Here's a list of ops that really do not
6153    # like random inputs for which we want to disable that.
6154    cant_check_data_specialization = set(
6155        {
6156            "nn.functional.max_unpool1d",
6157            "nn.functional.max_unpool2d",
6158            "nn.functional.max_unpool3d",
6159        }
6160    )
6161    try_check_data_specialization = op.name not in cant_check_data_specialization
6162
6163    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
6164    for sample_input in sample_inputs_itr:
6165        t_args = [sample_input.input] + list(sample_input.args)
6166        t_kwargs = sample_input.kwargs
6167        try:
6168            aot_autograd_check(
6169                op.op,
6170                t_args,
6171                t_kwargs,
6172                dynamic,
6173                self.assertRaisesRegex,
6174                self.assertEqual,
6175                check_gradients=True,
6176                try_check_data_specialization=try_check_data_specialization,
6177            )
6178        except DynamicOutputShapeException:
6179            self.skipTest("Dynamic output shape operation in trace")
6180        except GuardOnDataDependentSymNode:
6181            # Carveout for getitem; I don't want to xfail the entire test
6182            # because that will reject known to be good tests see
6183            # https://github.com/pytorch/pytorch/issues/94705
6184            if op.name == "__getitem__":
6185                self.skipTest("Dynamic output shape operation in trace")
6186            else:
6187                raise
6188
6189
6190def _test_aot_autograd_module_helper(
6191    self, device, dtype, training, module_info, *, dynamic=False
6192):
6193    module_cls = module_info.module_cls
6194    module_inputs = module_info.module_inputs_func(
6195        module_info, device=device, dtype=dtype, requires_grad=True, training=training
6196    )
6197    for module_input in module_inputs:
6198        if module_input.forward_input is None:
6199            continue
6200
6201        args, kwargs = (
6202            module_input.constructor_input.args,
6203            module_input.constructor_input.kwargs,
6204        )
6205        m = module_cls(*args, **kwargs)
6206        m.to(device).to(dtype)
6207        m.train(training)
6208
6209        # Lazy modules need to see an input first to initialize params.
6210        args, kwargs = (
6211            module_input.forward_input.args,
6212            module_input.forward_input.kwargs,
6213        )
6214        flat_args, args_spec = pytree.tree_flatten((args, kwargs))
6215
6216        # PackedSequence is only used for RNNs. It might be possible to fake-ify if they're pytrees but
6217        # torchdynamo already doesn't support RNNs
6218        if any(tuple(isinstance(flat_arg, PackedSequence) for flat_arg in flat_args)):
6219            continue
6220
6221        if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
6222            with torch.no_grad():
6223                m(*args, **kwargs)
6224
6225        sentinel_val = -42
6226        is_tensor_spec = [
6227            sentinel_val if isinstance(arg, torch.Tensor) else arg for arg in flat_args
6228        ]
6229        args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]
6230
6231        def f(params_buffers_args):
6232            named_params, named_buffers, args = params_buffers_args
6233            cur_flat_args = list(is_tensor_spec)
6234            args = iter(args)
6235            for idx, v in enumerate(cur_flat_args):
6236                if v == sentinel_val:
6237                    cur_flat_args[idx] = next(args)
6238            c_args, c_kwargs = pytree.tree_unflatten(cur_flat_args, args_spec)
6239            params_and_buffers = {**named_params, **named_buffers}
6240            return torch.func.functional_call(m, params_and_buffers, c_args, c_kwargs)
6241
6242        named_params = dict(m.named_parameters(remove_duplicate=False))
6243        named_buffers = dict(m.named_buffers(remove_duplicate=False))
6244        num_params_buffers = len(named_params) + len(named_buffers)
6245        compiled_f = aot_function(
6246            f, nop, num_params_buffers=num_params_buffers, dynamic=dynamic
6247        )
6248        params_buffers_args = [named_params, named_buffers, args]
6249        _test_aot_autograd_forwards_backwards_helper(
6250            f,
6251            compiled_f,
6252            params_buffers_args,
6253            self.assertRaisesRegex,
6254            self.assertEqual,
6255            True,
6256        )
6257
6258
6259class TestEagerFusionOpInfo(AOTTestCase):
6260    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
6261    @skipOps(
6262        "TestEagerFusionOpInfo", "test_aot_autograd_exhaustive", aot_autograd_failures
6263    )
6264    def test_aot_autograd_exhaustive(self, device, dtype, op):
6265        _test_aot_autograd_helper(self, device, dtype, op)
6266
6267    @ops(op_db + hop_db, allowed_dtypes=(torch.float,))
6268    @patch("functorch.compile.config.debug_assert", True)
6269    @skipOps(
6270        "TestEagerFusionOpInfo",
6271        "test_aot_autograd_symbolic_exhaustive",
6272        aot_autograd_failures | symbolic_aot_autograd_failures,
6273    )
6274    def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
6275        _test_aot_autograd_helper(self, device, dtype, op, dynamic=True)
6276
6277
6278aot_autograd_module_failures = set(
6279    {
6280        torch.nn.CTCLoss,  # torch._subclasses.fake_tensor.DynamicOutputShapeException: aten._ctc_loss.default
6281        torch.nn.GaussianNLLLoss,  # RuntimeError: It appears that you're trying to get value out
6282        # of a tracing tensor with aten._local_scalar_dense.default -
6283        # erroring out! It's likely that this is caused by data-dependent
6284        # control flow or similar.
6285        torch.nn.MultiLabelMarginLoss,  # AssertionError: The values for attribute 'shape' do not match:
6286        # torch.Size([1]) != torch.Size([]). Outputs of the operator are different in
6287        # eager-mode PyTorch vs AOTAutograd. This means the operator will have incorrect
6288        # output underneath torch.compile. This could be because the operator's
6289        # implementation not traceable or that there is a bug in AOTAutograd.
6290        torch.nn.TransformerEncoder,  # DataDependentOutputException: aten.eq compares a mask input
6291        # to a causal mask tensor, to see if Boolean is_causal should be set
6292        # for TrnasformerEncoder layers, MHA and sdp custom kernels
6293        torch.nn.Transformer,  # DataDependentOutputException: aten.equal compares a mask input
6294        # to a causal mask tensor, to see if Boolean is_causal should be set
6295        # for TransformerEncoder layers, MHA and sdp custom kernels
6296        # (this bubbles up to Transformer)
6297    }
6298)
6299
6300symbolic_aot_autograd_module_failures = {
6301    torch.nn.Transformer,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
6302    torch.nn.TransformerEncoder,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
6303    torch.nn.GaussianNLLLoss,  # NotImplementedError: local_scalar_dense/item NYI for torch.bool
6304    torch.nn.GroupNorm,  # in native_group_norm_backward cpg, _rem = divmod(C, group)
6305    # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
6306    torch.nn.FractionalMaxPool3d,  # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
6307    torch.nn.BCELoss,  # new_size = _infer_size(target.size(), weight.size())
6308    # RuntimeError: expected int at position 0, but got: SymInt
6309}
6310
6311
6312class TestEagerFusionModuleInfo(AOTTestCase):
6313    @modules(module_db, allowed_dtypes=(torch.float,))
6314    @decorateForModules(unittest.expectedFailure, aot_autograd_module_failures)
6315    def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_info):
6316        _test_aot_autograd_module_helper(self, device, dtype, training, module_info)
6317
6318    @modules(module_db, allowed_dtypes=(torch.float,))
6319    @decorateForModules(
6320        unittest.expectedFailure,
6321        aot_autograd_module_failures | symbolic_aot_autograd_module_failures,
6322    )
6323    def test_aot_autograd_symbolic_module_exhaustive(
6324        self, device, dtype, training, module_info
6325    ):
6326        _test_aot_autograd_module_helper(
6327            self, device, dtype, training, module_info, dynamic=True
6328        )
6329
6330
6331instantiate_parametrized_tests(TestAOTAutograd)
6332only_for = "cpu"
6333instantiate_device_type_tests(
6334    TestPythonKey,
6335    globals(),
6336    only_for=only_for,
6337)
6338instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for)
6339instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for)
6340
6341
6342@xfail_inherited_tests(
6343    [
6344        "test_set__and_data_mutation_bad",
6345        "test_subclass_metadata_mutation_req_grad_True",
6346        "test_subclass_metadata_mutation_req_grad_False",
6347    ]
6348)
6349@skipIfTorchDynamo("This test suite already uses dynamo")
6350class TestAOTAutogradWithDynamo(TestAOTAutograd):
6351    """
6352    These are the same as TestAOTAutograd tests, but we run dynamo first to get a graph module.
6353    """
6354
6355    def assertExpectedInline(self, *args, **kwargs):
6356        # These will have different outputs because dynamo returns a different graph module
6357        # But we don't really care about that assertion when testing with dynamo,
6358        # only that the outputs match, etc.
6359        pass
6360
6361    def make_compiler(self, graph_cell):
6362        return make_boxed_compiler(partial(extract_graph, graph_cell=graph_cell))
6363
6364    # Compiler to passes to dynamo
6365    def run_autograd(
6366        self,
6367        f: Callable,
6368        fw_graph_cell: List[Optional[Callable]],
6369        decompositions: Optional[Dict],
6370        keep_input_mutations: bool,
6371        dynamic: bool,
6372    ):
6373        """
6374        Runs dynamo and aot_autograd with the specified settings
6375        """
6376
6377        def dynamo_compiler(gm, inputs, **kwargs):
6378            result = aot_module_simplified(
6379                gm,
6380                inputs,
6381                fw_compiler=self.make_compiler(fw_graph_cell),
6382                bw_compiler=self.make_compiler([None]),
6383                decompositions=decompositions,
6384                keep_inference_input_mutations=keep_input_mutations,
6385                # Dynamic is calculated from whether the inputs have fake tensors
6386            )
6387            return result
6388
6389        def torch_compile_wrapper(*args, **kwargs):
6390            torch._dynamo.reset()
6391            fn = torch.compile(f, backend=dynamo_compiler)
6392            try:
6393                result = fn(*args, **kwargs)
6394            except torch._dynamo.exc.BackendCompilerFailed as e:
6395                # So that assertRaises works properly
6396                raise e.inner_exception from e
6397            return result
6398
6399        return torch_compile_wrapper
6400
6401
6402class MockFXGraphCache:
6403    """
6404    In memory version of FXGraphCache so we can isolate testing for FXGraphCache
6405    """
6406
6407    def __init__(self) -> None:
6408        self.cache = {}
6409
6410    def save(self, key, gm):
6411        self.cache[key] = gm
6412
6413    def load(self, gm, inputs):
6414        key, _ = compiled_fx_graph_hash(gm, inputs, {}, {})
6415        if key in self.cache:
6416            gm = make_boxed_func(gm)
6417            gm._fx_graph_cache_key = key
6418            return gm
6419        else:
6420            self.save(key, gm)
6421            gm = make_boxed_func(gm)
6422            gm._fx_graph_cache_key = key
6423            return gm
6424
6425    def _lookup_graph(self, key, inputs, local, remote_cache):
6426        gm = self.cache.get(key)
6427        if gm is not None:
6428            gm = make_boxed_func(gm)
6429        return gm
6430
6431    def post_compile(self, gm, inputs, cudagraphs):
6432        pass
6433
6434
6435# The following tests fail in strict caching mode (i.e. they bypass or
6436# cache miss instead of cache hitting). They will be fixed in the PRs above this.
6437FAILING_CACHE_TESTS = (
6438    # BypassAOTAutogradCache: unsupported nodes
6439    "test_backward_mutation_data",  # Custom Autograd Function
6440    "test_backward_mutation_metadata",  # Custom Autograd Function
6441    "test_custom_autograd",  # Custom Autograd Function
6442    "test_input_output_aliase_custom_autograd_function",
6443)
6444
6445
6446@xfail_inherited_tests(FAILING_CACHE_TESTS)
6447class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
6448    """
6449    In memory version of FXGraphCache so we can isolate testing for FXGraphCache
6450    """
6451
6452    def make_compiler(self, fw_graph_cell):
6453        mock_inductor_cache = self.inductor_cache
6454
6455        def compiler(gm, inputs):
6456            nonlocal mock_inductor_cache, fw_graph_cell
6457            result = mock_inductor_cache.load(gm, inputs)
6458            fw_graph_cell[0] = gm
6459            return result
6460
6461        return compiler
6462
6463    def run_autograd(
6464        self,
6465        f: Callable,
6466        fw_graph_cell: List[Optional[Callable]],
6467        decompositions: Optional[Dict],
6468        keep_input_mutations: bool,
6469        dynamic: bool,
6470    ):
6471        return super().run_autograd(
6472            f,
6473            fw_graph_cell,
6474            decompositions,
6475            keep_input_mutations,
6476            dynamic,
6477        )
6478
6479    @torch._functorch.config.patch(
6480        {
6481            "enable_autograd_cache": True,
6482            "strict_autograd_cache": True,
6483            "view_replay_for_aliased_outputs": False,
6484        }
6485    )
6486    @torch._inductor.config.patch("fx_graph_cache", True)
6487    def verify_aot_autograd(
6488        self,
6489        f,
6490        inp_: Union[Callable, List[Any]],
6491        *,
6492        test_mutation: bool = False,
6493        keep_inp_mutations: bool = False,
6494        decompositions: Optional[Dict] = None,
6495        dynamic: bool = False,
6496        # Only active when inp_ is Callable.
6497        # TODO: probably consolidate all tests to make inp a Callable.
6498        make_inputs_subclasses: bool = False,
6499    ):
6500        self.inductor_cache = MockFXGraphCache()
6501        AOTAutogradCache.clear()
6502        with patch(
6503            "torch._inductor.codecache.FxGraphCache._lookup_graph",
6504            new=self.inductor_cache._lookup_graph,
6505        ), patch(
6506            "torch._inductor.codecache.FxGraphCache.post_compile",
6507            new=self.inductor_cache.post_compile,
6508        ):
6509            return super().verify_aot_autograd(
6510                f,
6511                inp_,
6512                test_mutation=test_mutation,
6513                keep_inp_mutations=keep_inp_mutations,
6514                decompositions=decompositions,
6515                dynamic=dynamic,
6516                make_inputs_subclasses=make_inputs_subclasses,
6517            )
6518
6519    def test_input_mutation_false_aliasing(self):
6520        # This test is disabled because it fails in strict cache mode
6521        # But also can't be xfailed because it causes undefined behavior for
6522        # ASAN
6523        self.skipTest("Skipping because it fails in strict cache mode")
6524
6525
6526if __name__ == "__main__":
6527    run_tests()
6528