# Owner(s): ["module: dynamo"]
import copy
import re
import unittest
from textwrap import dedent
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.testing import CompileCounter, expectedFailureDynamic, rand_strided
from torch._functorch.aot_autograd import _aot_export_function, create_functional_call
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.profiler import profile
from torch.testing import FileCheck
from torch.testing._internal.common_utils import compare_equal_outs_and_grads


def maybe_dupe_op(x):
    y = x + 1
    z = x + 2
    if x.numel() < 5:
        return y, y
    else:
        return y, z


def is_dynamic_shape_test(test_name):
    return test_name.endswith("_dynamic_shapes")


aten = torch.ops.aten
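# Register custom::maybe_dupe_op with both CPU and Meta kernels so it can run
# eagerly and under fake tensors. Note that for inputs with numel() < 5 the op
# returns the same tensor twice (aliased outputs), which is what
# test_multiple_aot_autograd_calls_dupe_args below relies on.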
lib = torch.library.Library("custom", "DEF")  # noqa: TOR901
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")


class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
    def test_LSTM(self):
        # https://github.com/pytorch/torchdynamo/issues/1147
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.self_mod_model_lstm_lstm = torch.nn.LSTM(
                    64, 64, num_layers=2, bidirectional=True
                )

            def forward(self, permute: torch.Tensor):
                self_mod_model_lstm_lstm = self.self_mod_model_lstm_lstm(permute)
                return (self_mod_model_lstm_lstm,)

        mod = Repro()

        aot_mod = torch._dynamo.optimize("aot_eager")(mod)

        args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)]
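        # Each spec is (shape, stride, dtype, device, requires_grad); rand_strided
        # materializes a tensor with exactly that layout.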
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]

        eager_result = mod(*args)
        aot_result = aot_mod(*args)
        self.assertTrue(torch._dynamo.testing.same(eager_result, aot_result))

    def test_mutation(self):
        # https://github.com/pytorch/torchdynamo/issues/1301
        def fn(param, y):
            prev_grad = torch.is_grad_enabled()
            try:
                torch.set_grad_enabled(False)
                param.add_(y)
            finally:
                torch.set_grad_enabled(prev_grad)
            return y

        y = torch.randn(4)
        x = torch.nn.Parameter(torch.randn(4))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        # This should not error: we mutated an autograd leaf under no_grad mode.
        aot_fn(x, y)

    def test_mutation1(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_negative_testing_mutation(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0 = torch.sin(_stack0)
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_negative_testing(self):
        def fn(x, y):
            return torch.sin(x).add_(y)

        y = torch.randn(4)
        x = torch.randn(4)
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        aot_fn(x, y)

    def test_call_fn_with_non_const_inputs_aot_safe(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=20, kernel_size=(5, 5)
                )

            def _conv_forward(self, x):
                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)

            def forward(self, x):
                return self._conv_forward(x)

        # Init mod
        mod = ModuleSpecialFwd()
        rx = torch.randn([3, 10, 10])

        # Run it for real
        real = mod(rx)

        # Run it in export
        graph, _ = torch._dynamo.export(mod)(rx)

        # Run exported graph with AOT
        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))

        aot_fn = torch._dynamo.optimize("aot_eager")(graph)
        aot_fn(rx)

    def test_call_fn_with_non_const_inputs_aot_unsafe(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def _some_bad_fwd(self, param, y):
                prev_grad = torch.is_grad_enabled()
                try:
                    torch.set_grad_enabled(False)
                    param.add_(y)
                finally:
                    torch.set_grad_enabled(prev_grad)
                return y

            def forward(self, x, y):
                return self._some_bad_fwd(x, y)

        # Init mod
        mod = ModuleSpecialFwd()
        x = torch.nn.Parameter(torch.randn(4))
        y = torch.randn([4])

        # Run it for real
        real = mod(x, y)

        # Run it in export
        graph, _ = torch._dynamo.export(mod)(x, y)

        # Assert equal
        self.assertTrue(torch._dynamo.testing.same(real, graph(x, y)))

        # Run exported graph with AOT
        aot_fn = torch._dynamo.optimize("aot_eager")(graph)
        # This should not error: we mutated an autograd leaf under no_grad mode.
        aot_fn(x, y)

    def test_call_fn_with_non_const_inputs_aot_unsafe_control_flow(self):
        class ModuleSpecialFwd(torch.nn.Module):
            def _some_bad_fwd(self, param, y):
                if y[0][0] < 3:
                    return y + param
                return param * y

            def forward(self, x, y):
                a = x * y
                a = self._some_bad_fwd(a, a)
                b = x + y
                return a * b

        # Init mod
        mod = ModuleSpecialFwd()
        x = torch.nn.Parameter(torch.randn([2, 2]))
        y = torch.randn([2, 2])

        # Run it for real
        real = mod(x, y)

        # Run it through optimize, with our capturing fn

        gms = []
        counter = CompileCounter()

        def capturing_fn(gm, inputs):
            nonlocal gms
            gms.append(gm)
            return counter(gm, inputs)

        optimized_mod = torch._dynamo.optimize(capturing_fn)(mod)

        # Assert equal
        self.assertTrue(torch._dynamo.testing.same(real, optimized_mod(x, y)))

        # Uncomment to reproduce commented out graphs below.
        # for gm in gms:
        #     print("GM CODE", gm.code)

        self.assertEqual(counter.frame_count, 4)
        self.assertEqual(counter.op_count, 7)
        # Graph 1
        # def forward(self, x : torch.nn.parameter.Parameter, y : torch.Tensor):
        #     mul = x * y;  x = y = None
        #     return (mul,)
        # BREAK
        # Graph 2
        # def forward(self, y : torch.Tensor):
        #     getitem = y[0];  y = None
        #     getitem_1 = getitem[0];  getitem = None
        #     lt = getitem_1 < 3;  getitem_1 = None
        #     return (lt,)
        # BREAK
        # Graph 3
        # def forward(self, param : torch.Tensor, y : torch.Tensor):
        #     add = y + param;  y = param = None
        #     return (add,)
        # BREAK
        # Graph 4
        # def forward(self, _stack0 : torch.Tensor, x : torch.nn.parameter.Parameter, y : torch.Tensor):
        #     add = x + y;  x = y = None
        #     mul = _stack0 * add;  _stack0 = add = None
        #     return (mul,)

        # Run fn with AOT
        torch._dynamo.reset()

        aot_fn = torch._dynamo.optimize("aot_eager")(optimized_mod)
        aot_fn(x, y)

    # Note: Dynamo recompilation guarding invalid grad
    #
    # This test is a spiritual equivalent of test_invalid_requires_grad_fake in test_autodispatch.py.
    # It invokes aot_autograd in a way that would normally trigger an assertion
    # (which is what test_invalid_requires_grad_fake does). The point of this test is to prove
    # that we do not hit that assertion, because dynamo recompiles correctly and guards against this condition.
    #
    # Subnote: the reason test_invalid_requires_grad_fake uses fake tensors
    # is that dynamo sends fake tensors down to aot_autograd.
    @patch("torch._functorch.config.debug_assert", True)
    def test_requires_grad_fake_via_dynamo_recompiles(self):
        class F(torch.nn.Module):
            def forward(self, x, y):
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)
        z = torch.randn(3, 3, requires_grad=False)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        compare_equal_outs_and_grads(self, F(), fxy, (x, y))
        compare_equal_outs_and_grads(self, F(), fxy, (x, z))
        self.assertIn(
            """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
            failure_reason,
        )

        # Reset failure reason
        failure_reason = None

        self.assertEqual(cc.frame_count, 2)

        torch._dynamo.reset()  # for new backend
        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        fxz = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        compare_equal_outs_and_grads(self, F(), fxz, (x, z))
        compare_equal_outs_and_grads(self, F(), fxz, (x, z))
        self.assertEqual(cc.frame_count, 1)
        self.assertTrue(failure_reason is None)

    def test_double_backward_errors(self):
        # Remove this test after we get double backward to actually work
        for grad_output in (torch.tensor(1.0, requires_grad=True), None):
            x = torch.tensor(1.0, requires_grad=True)
            err = "torch.compile with aot_autograd does not currently support double backward"

            # The following cases should be equivalent:

            # (1) double backward entirely inside compiled function
            def f1(x):
                y = x.sin().exp()
                (gx,) = torch.autograd.grad(
                    y, x, create_graph=True, grad_outputs=grad_output
                )
                torch.autograd.grad(gx, x)
                return gx

            compiled_f1 = torch.compile(backend="aot_eager")(f1)
            f1(x)
            with self.assertRaisesRegex(RuntimeError, err):
                compiled_f1(x)

            # (2) the second half of double backward outside compiled function
            def f2(x):
                y = x.sin().exp()
                (gx,) = torch.autograd.grad(
                    y, x, create_graph=True, grad_outputs=grad_output
                )
                return gx

            compiled_f2 = torch.compile(backend="aot_eager")(f2)
            gx = compiled_f2(x)
            with self.assertRaisesRegex(RuntimeError, err):
                torch.autograd.grad(gx, x)

            # (3) double backward entirely outside compiled function
            def f3(x):
                y = x.sin().exp()
                return y

            compiled_f3 = torch.compile(backend="aot_eager")(f3)
            y = compiled_f3(x)
            (gx,) = torch.autograd.grad(
                y, x, create_graph=True, grad_outputs=grad_output
            )
            with self.assertRaisesRegex(RuntimeError, err):
                torch.autograd.grad(gx, x)

        # create_graph=False
        def f4(x):
            y = x.sin().exp()
            return y

        compiled_f4 = torch.compile(backend="aot_eager")(f4)
        x = torch.tensor(1.0, requires_grad=True)
        y = compiled_f4(x)
        (gx,) = torch.autograd.grad(y, x, create_graph=False, grad_outputs=grad_output)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles(self):
        class F(torch.nn.Module):
            def forward(self, x, y):
                x = x.trunc_()
                y = y.trunc_()
                return (x + y,)

        x = torch.randn(3, 3, requires_grad=True)
        x1, x2, x3, x4 = x.clone(), x.clone(), x.clone(), x.clone()
        y = torch.randn(3, 3, requires_grad=True)
        y1, y2, y4 = y.clone(), y.clone(), y.clone()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        # Note: to prevent a recompilation between the two calls,
        # we need to clone x and y on each use.
        # fxy mutates the input's metadata, so otherwise dynamo will end up recompiling.
        fxy(x1, y1)
        fxy(x2, y2)

        self.assertTrue(failure_reason is None)

        # Reset failure reason
        failure_reason = None

        self.assertEqual(cc.frame_count, 1)

        torch._dynamo.reset()  # for new backend
        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        fxx = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        fxx(x3, x3)
        fxx(x4, y4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['x'] is L['y']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b, e, f):
                a.trunc_()
                b.trunc_()
                return (a + b + self.mean) * e * f

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, 2, 2)
        f(a2, b2, 2, 2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(c3, c3, 3, 3)
        f(c4, d4, 3, 3)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['a'] is L['b']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
        z = None

        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b, e, f):
                a.trunc_()
                b.trunc_()
                return (a + b + z + self.mean) * e * f

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        z = a
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, 2, 2)
        f(a2, b2, 2, 2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg_list(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, e, f, a, b):
                a.trunc_()
                b.trunc_()
                return (a + b + self.mean) * e[0] * f[0]

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f([3, 2, 1], [4, 5, 6], a1, a1)
        f([3, 2, 1], [4, 5, 6], a2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f([3, 2, 1], [4, 5, 6], c3, c3)
        f([3, 2, 1], [4, 5, 6], c4, d4)
        self.assertEqual(cc.frame_count, 2)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args_param(self):
        class F(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mean = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self, a, b):
                a.trunc_()
                b.trunc_()
                return a + b + self.mean

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2 = a.clone(), a.clone()
        b1, b2 = b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1)
        f(a2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(c3, c3)
        f(c4, d4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['a'] is L['b']""", failure_reason)

    @patch("torch._functorch.config.debug_assert", True)
    def test_arg_dupe_via_dynamo_recompiles_many_args(self):
        class F(torch.nn.Module):
            def forward(self, a, b, c, d):
                a.trunc_()
                b.trunc_()
                c.trunc_()
                d.trunc_()
                return (a + b + c + d,)

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
        b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        self.assertTrue(failure_reason is None)

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a1, a1, a1, a1)
        f(a2, b2, b2, b2)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn(
            """L['a'] is L['b']""",
            failure_reason,
        )

        torch._dynamo.reset()

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        c = torch.randn(3, 3, requires_grad=True)
        d = torch.randn(3, 3, requires_grad=True)
        c3, c4 = c.clone(), c.clone()
        d3, d4 = d.clone(), d.clone()

        f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        f(a3, b3, c3, c3)
        f(a4, b4, c4, d4)
        self.assertEqual(cc.frame_count, 2)
        self.assertIn("""L['c'] is L['d']""", failure_reason)

    def test_alias_inputs(self):
        def fn():
            a = torch.tensor([1])
            a = a[0:1]
            b = a.squeeze()
            a[0] = 0
            if a[0] < 1e5:
                pass
            a[0] = 2
            return b

        ref_output = fn()
        aot_fn = torch._dynamo.optimize("aot_eager")(fn)
        actual_output = aot_fn()
        self.assertEqual(ref_output, actual_output)

    def test_grad_inputs_alias_inputs(self):
        class Test(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad):
                (x,) = ctx.saved_tensors
                return x, grad

        def fn(x, y):
            return Test.apply(x, y)

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        compiled_fn = torch.compile(fn, backend="aot_eager")
        out = compiled_fn(x, y)
        out.sum().backward()

    @expectedFailureDynamic  # https://github.com/pytorch/pytorch/issues/103539
    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
    @patch("torch._functorch.config.debug_assert", True)
    def test_multiple_aot_autograd_calls_dupe_args(self):
        # this is just dealing with the fact that
        # aot_module_simplified expects submods to always return tuples/lists
        class WrapperModule(torch.nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.mod = mod

            def forward(self, *args):
                out = self.mod(*args)
                if isinstance(out, (list, tuple)):
                    return out
                return (out,)

        def compile_submod(input_mod, args):
            from functorch.compile import nop
            from torch._functorch.aot_autograd import aot_module_simplified

            class WrapperModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.original = input_mod
                    self.submod = aot_module_simplified(input_mod, args, nop)

                def forward(self, *args):
                    return self.submod(*args)

            return WrapperModule()

        def test_compile(fx_g, example_inps):
            split_gm = torch.fx.passes.split_module.split_module(
                fx_g, None, lambda node: 1 if "mul" in str(node) else 0
            )
            submod_1_inps = split_gm.submod_0(*example_inps)
            split_gm.submod_0 = compile_submod(
                WrapperModule(split_gm.submod_0), example_inps
            )
            split_gm.submod_1 = compile_submod(
                WrapperModule(split_gm.submod_1), submod_1_inps
            )
            return split_gm

        @torch._dynamo.optimize(test_compile)
        def f(a):
            b, c = torch.ops.custom.maybe_dupe_op(a)
            return (b.mul_(c),)

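        # ones(4) hits the numel() < 5 branch, so b and c alias the same tensor;
        # ones(6) does not, so the second call sees a different duplication pattern.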
        f(torch.ones(4))
        f(torch.ones(6))

    def test_nn_parameter_construction(self):
        # https://github.com/pytorch/pytorch/issues/99569
        def fn(x):
            y = x.sin()
            z = torch.nn.Parameter(torch.ones(1))
            return y + z

        x = torch.rand((4, 4))

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        self.assertTrue(torch._dynamo.testing.same(fn(x), opt_fn(x)))

    def test_aot_sequence_nr(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                x = self.fc1(x)
                output = self.loss_fn(x, target)

                return (output,)

        mod = Model()
        mod.train()
        x = torch.rand(100, 16, 32, 32, requires_grad=True)
        target = torch.rand(1)

        # Use dynamo export to get the fx graph module
        g_mod, _ = torch._dynamo.export(mod, x, target)

        def _prepare_model_args():
            named_parameters = dict(g_mod.named_parameters(remove_duplicate=False))
            named_buffers = dict(g_mod.named_buffers(remove_duplicate=False))
            params_and_buffers = {
                **dict(named_parameters),
                **dict(named_buffers),
            }
            params_and_buffers_flat, params_spec = pytree.tree_flatten(
                params_and_buffers
            )
            params_len = len(params_and_buffers_flat)
            functional_call = create_functional_call(g_mod, params_spec, params_len)
            return params_and_buffers_flat, functional_call

        full_args, fn_to_trace = _prepare_model_args()
        param_and_buf_len = len(full_args)
        full_args.extend([x, target])
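        # full_args is now the flattened params/buffers followed by the example
        # inputs; fn_to_trace expects them in that order, which is why
        # num_params_buffers is passed to _aot_export_function below.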

        # aot_export requires the fwd graph as a GraphModule input and
        # returns the full joint fwd/bwd graph as a GraphModule.
        with torch.enable_grad(), fx_traceback.preserve_node_meta():
            fx_g, _, _, _ = _aot_export_function(
                fn_to_trace,
                full_args,
                decompositions=None,
                num_params_buffers=param_and_buf_len,
                no_tangents=True,
            )

        # Walk all the nodes in fx graph.
        # Write the resulting ops to a table
        min_seq_nr = -1
        seq_table = "SeqNr|OrigAten|SrcFn|FwdSrcFn\n"
        for node in fx_g.graph.nodes:
            if "call_" in node.op and "getitem" not in str(node.target):
                seq_nr = node.meta.get("seq_nr", -1)
                if seq_nr < 0:
                    continue
                if min_seq_nr < 0:
                    min_seq_nr = seq_nr
                source_fn_stack = node.meta.get("source_fn_stack", [])
                orig_aten = node.meta.get("original_aten", "")
                mod_name = ""
                if len(source_fn_stack) > 0:
                    mod_name = source_fn_stack[-1][0]
                # Make all seq_nr relative so it starts at 0
                seq_nr = seq_nr - min_seq_nr
                # For backward nodes, also test that metadata from the corresponding
                # forward node is copied over.
                fwd_source_fn_stack = node.meta.get("fwd_source_fn_stack", [])
                fwd_mod_name = ""
                if len(fwd_source_fn_stack):
                    fwd_mod_name = fwd_source_fn_stack[-1][0]
                seq_table = (
                    seq_table + f"{seq_nr}|{orig_aten}|{mod_name}|{fwd_mod_name}\n"
                )

        self.maxDiff = None
        self.assertExpectedInline(
            seq_table,
            dedent(
                """\
SeqNr|OrigAten|SrcFn|FwdSrcFn
0|aten.convolution.default|l__self___conv1|
0|aten.add.Tensor|l__self___bn1|
1|aten._native_batch_norm_legit_functional.default|l__self___bn1|
2|aten.relu.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|l__self___fc1|
6|aten.t.default|l__self___fc1|
7|aten.addmm.default|l__self___fc1|
8|aten.view.default|l__self___fc1|
9|aten.sub.Tensor|l__self___loss_fn|
10|aten.abs.default|l__self___loss_fn|
11|aten.mean.default|l__self___loss_fn|
11|aten.ones_like.default||l__self___loss_fn
11|aten.expand.default||l__self___loss_fn
11|aten.div.Scalar||l__self___loss_fn
10|aten.sgn.default||l__self___loss_fn
10|aten.mul.Tensor||l__self___loss_fn
8|aten.view.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.sum.dim_IntList||l__self___fc1
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1
1|aten.native_batch_norm_backward.default||l__self___bn1
0|aten.convolution_backward.default||l__self___conv1
11|aten.add.Tensor||l__self___loss_fn
"""
            ),
        )

    def test_split_with_sizes_aot_autograd_cleans_up_traceback_meta(self):
        from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks

        def fn(result, split_sizes):
            rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
            return rs

        example_inputs = (
            torch.randn(32, requires_grad=True),
            torch.tensor((7, 16, 9)),
        )
        outs = fn(*example_inputs)
        setup_stacktrace_preservation_hooks([out.grad_fn for out in outs])
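        # The hooks record grad_fn info (e.g. grad_fn_seq_nr / in_grad_fn) in
        # fx_traceback.current_meta while backward runs; the assertions below
        # check that this metadata is cleaned up once backward finishes.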
        with fx_traceback.preserve_node_meta():
            (outs[0].sum() + outs[1].sum() + outs[2].sum()).backward()

        self.assertNotIn("grad_fn_seq_nr", fx_traceback.current_meta)
        self.assertNotIn("in_grad_fn", fx_traceback.current_meta)

    # https://github.com/pytorch/pytorch/issues/110121
    def test_aot_export_joint_simple_repro(self):
        class Mod(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.linear = torch.nn.Linear(5, 7)

            def forward(self, x):
                return self.linear(x)

        def mini_backend(gm, sample_inputs):
            from torch._functorch.aot_autograd import aot_export_joint_simple

            fake_mode = torch._dynamo.utils.detect_fake_mode(sample_inputs)

            with patch.object(fake_mode, "allow_non_fake_inputs", True), fake_mode:
                return aot_export_joint_simple(gm, sample_inputs, trace_joint=False)

        sample_inputs = [torch.rand((3, 4, 5))]
        model = Mod()
        m_compiled = torch.compile(model, backend=mini_backend)

        out_ref = model(*sample_inputs)
        out_test = m_compiled(*sample_inputs)
        self.assertEqual(out_ref, out_test)

    def test_eager_sequence_nr(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                x = self.fc1(x)
                output = self.loss_fn(x, target)

                return (output,)

        def grad_with_create_graph(mod, x, target):
            y = mod(x, target)
            # Set create_graph=True to ensure that the sequence_nr
            # for backward ops continues to count down.
            (gx,) = torch.autograd.grad(
                y[0], x, create_graph=True, grad_outputs=grad_output
            )
            return gx

        x = torch.rand(100, 16, 32, 32, requires_grad=True)
        target = torch.rand(1)
        mod = Model()
        args = [mod, x, target]
        grad_output = torch.tensor(1.0, requires_grad=True)
        compiled_f1 = torch.compile(backend="aot_eager")(grad_with_create_graph)
        model_instance = compiled_f1
        with profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            record_shapes=True,
        ) as kineto_prof:
            res = model_instance(*args)
        bwd_set = set()
        prof_str = "SeqNr|Thread|FwdThread|Name\n"
        for event in kineto_prof.events():
            if event.sequence_nr >= 0:
                prof_str = (
                    prof_str + f"{event.sequence_nr}|{event.thread}"
                    f"|{event.fwd_thread}|{event.name}|\n"
                )
                if re.search(r"Backward[01]", event.name):
                    bwd_set.add(event.sequence_nr)
        self.assertEqual(len(bwd_set), 13)

    def test_aot_grad_mode_mutation(self):
        for compiler in ["aot_eager", "inductor"]:

            def f(x):
                y = x * x
                torch.set_grad_enabled(False)
                return y.clone(), y

            f_compiled = torch.compile(f, backend=compiler, fullgraph=True)

            torch.set_grad_enabled(True)
            x = torch.ones(3, requires_grad=True) * 3
            y_ref = f(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            y = f_compiled(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            self.assertEqual(y_ref, y)

            self.assertIsNone(y_ref[0].grad_fn)
            self.assertIsNone(y[0].grad_fn)

            self.assertIsNotNone(y_ref[1].grad_fn)
            self.assertIsNotNone(y[1].grad_fn)

            # Check that the grad computed for the inputs, given the same tangents, is the same.
            # The tangent to `y[0]`, which has requires_grad=False, is irrelevant.
            self.assertEqual(
                sum(y_ref[1].grad_fn(torch.tensor([-1.0, 2.0, 0.0]))),
                sum(
                    x
                    for x in y[1].grad_fn.apply(None, torch.tensor([-1.0, 2.0, 0.0]))
                    if x is not None
                ),
            )

    def test_aot_autograd_raises_invalid_leaf_set(self):
        @torch.compile
        def f(x):
            x.set_(torch.ones(2))

        # We still want to make sure that this raises
        x = torch.ones(2, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError, "is being used in an in-place operation"
        ):
            f(x)

    def test_aot_autograd_expand_mutation_functionalizes(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.clone().detach()
        self.assertEqual(fn(x), opt_fn(x_opt))
        self.assertEqual(x, x_opt)

    def test_aot_autograd_expand_mutation_backwards(self):
        def fn(x, z):
            y = x.expand(3, *x.shape)
            y[1, 1].mul_(5)
            ret = y * z
            return ret

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6, dtype=torch.float)
        z = x.clone().detach()
        x_opt = x.clone().detach()
        z_opt = x.clone().detach()

        z.requires_grad = True
        z_opt.requires_grad = True

        res = fn(x, z)
        opt_res = opt_fn(x_opt, z_opt)

        self.assertEqual(res, opt_res)

        res.sum().backward()
        opt_res.sum().backward()

        self.assertEqual(x, x_opt)
        self.assertEqual(z.grad, z_opt.grad)

    def test_data_ptr_access_copy(self):
        import torch._functorch.config as _config

        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
            with FakeTensorMode():
                x = torch.randn(3)
                y = copy.copy(x)
        self.assertEqual(y.shape, x.shape)

    def test_data_ptr_access_fails_in_forward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                x.data_ptr()
                return x.clone()

            x = torch.randn(3)

            def data_ptr_graph_input(x):
                r0 = torch.ops.mylib.foo(x)
                return r0

            def data_ptr_graph_intermediate(x):
                y = x.clone()
                r0 = torch.ops.mylib.foo(y)
                return r0

            tests = [data_ptr_graph_input, data_ptr_graph_intermediate]

            def ctx():
                return self.assertRaisesRegex(
                    RuntimeError, "Cannot access data pointer"
                )

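            # Each case should fail whether we trace with fake tensors, with
            # symbolic shapes, or run through torch.compile with the eager backend.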
            for f in tests:
                with ctx():
                    make_fx(f, tracing_mode="fake")(x)
                with ctx():
                    make_fx(f, tracing_mode="symbolic")(x)
                with ctx():
                    torch.compile(f, backend="eager", fullgraph=True)(x)

    def test_data_ptr_access_fails_in_backward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            backward_called = False

            class Foo(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    return x.clone()

                @staticmethod
                def backward(ctx, grad):
                    nonlocal backward_called
                    backward_called = True
                    grad.data_ptr()
                    return grad.clone()

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                return Foo.apply(x)

            def f(x):
                return torch.ops.mylib.foo(x)

            x = torch.randn(3, requires_grad=True)
            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
                y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
            self.assertTrue(backward_called)

    # We don't know how to catch multiple mutations to the same memory location
    @unittest.expectedFailure
    def test_aot_autograd_expand_mutation_error(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0:3, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.clone().detach()
        with self.assertRaises(Exception):
            fn(x)
        with self.assertRaises(Exception):
            opt_fn(x_opt)

    @torch._functorch.config.patch(donated_buffer=True)
    def test_donated_buffer1(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
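        # A "donated buffer" is a tensor saved for backward that is created inside
        # the compiled forward and is not aliased to any input or user output, so
        # its storage can be reused during the backward pass.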

        @torch.compile()
        def relu(x):
            return torch.nn.functional.relu(x)

        with self.assertLogs(logger_name, level="INFO") as captured:
            relu(torch.rand([3, 3], requires_grad=True)).sum().backward()

        if is_dynamic_shape_test(self._testMethodName):
            # an extra symint exists
            expected_msg = "bw_donated_idxs=[1]"
        else:
            expected_msg = "bw_donated_idxs=[0]"

        # le is a donated buffer from relu
        FileCheck().check(expected_msg).run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer2(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # the graph compiled for g may be re-used across different callers
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            activation = inp + param1
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer3(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # the graph compiled for g may be re-used across different callers
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            # exp saves its output (the activation) for the backward pass
            activation = torch.exp(inp + param1)
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer4(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([2, 2]))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(x) + self.param

        mod = Mod()
        mod = torch.compile(mod)

        inp = torch.ones([2, 2], requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            mod(inp).sum().backward()

        # Forward graph:
        #   %primals_1 : [num_users=1] = placeholder[target=primals_1]
        #   %primals_2 : [num_users=1] = placeholder[target=primals_2]
        #   %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %primals_1), kwargs = {})
        #   %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        #   return [add, le]
        #
        # `le` is a donated buffer
        FileCheck().check("bw_donated_idxs=[0]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer5(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        @torch.compile()
        def f(x, z):
            y = x.view(2, 3)
            z = torch.nn.functional.relu(z)
            return torch.mm(y, x) + z

        inp = [
            torch.rand([3, 2], requires_grad=True),
            torch.rand([2, 2], requires_grad=True),
        ]

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(*inp).sum().backward()

        # Forward graph:
        #   %primals_1 : [num_users=3] = placeholder[target=primals_1]
        #   %primals_2 : [num_users=1] = placeholder[target=primals_2]
        #   %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%primals_1, [2, 3]), kwargs = {})
        #   %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        #   %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %primals_1), kwargs = {})
        #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm, %relu), kwargs = {})
        #   %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        #   return [add, primals_1, le]
        #
        # `le` is a donated buffer but primals_1 is not.
        FileCheck().check("bw_donated_idxs=[1]").run("\n".join(captured.output))

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph1(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        for _ in range(5):
            mod(inp).sum().backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph2(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph3(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward(create_graph=True)
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()

    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph4(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward()
        out = mod(inp).sum()
        with self.assertRaisesRegex(
            RuntimeError,
            r"This backward function was compiled with non-empty donated "
            r"buffers which requires create_graph=False and retain_graph=False. "
            r"Please keep backward\(create_graph=False, retain_graph=False\) "
            r"across all backward\(\) function calls, or set "
            r"torch._functorch.config.donated_buffer=False to disable "
            r"donated buffer.",
        ):
            out.backward(retain_graph=True)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()