1# Owner(s): ["module: dynamo"]
2import enum
3import functools
4import pprint
5import re
6import unittest
7import warnings
8
9import functorch.experimental.control_flow as control_flow
10import torch
11import torch._dynamo.config as config
12import torch._dynamo.test_case
13import torch._functorch.config
14import torch.nn as nn
15import torch.utils._pytree as pytree
16import torch.utils.checkpoint
17from torch._dynamo.backends.common import aot_autograd
18from torch._dynamo.testing import (
19    CompileCounter,
20    CompileCounterWithBackend,
21    EagerAndRecordGraphs,
22    empty_line_normalizer,
23    normalize_gm,
24)
25from torch._dynamo.utils import counters, ifdynstaticdefault
26from torch._higher_order_ops.hints_wrap import hints_wrapper
27from torch._higher_order_ops.wrap import wrap
28from torch.testing._internal.common_utils import (
29    munge_exc,
30    TEST_WITH_TORCHDYNAMO,
31    xfailIfTorchDynamo,
32)
33from torch.testing._internal.inductor_utils import HAS_CUDA
34from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
35
36
37requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
38
39
40def check_dynamic_shape_capture():
41    # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
42    return not config.assume_static_by_default
43
44
45def count_ops(gm, args, freq, op):
46    actual = [node.target for node in gm.graph.nodes].count(op)
47    assert actual == freq, f"expected={freq}, actual={actual}"
48    return gm
49
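# A minimal sketch (illustrative only; this helper is not used by the tests
# and its name/op/frequency are made up) of how count_ops is typically wired
# up: partially apply it and hand it to aot_autograd as the fw_compiler, so
# the op frequency is asserted on the traced forward graph.
def _example_count_ops_backend():
    return aot_autograd(
        fw_compiler=functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sin.default
        )
    )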
50
51class Obj:
52    pass
53
54
55class MyModule(nn.Module):
56    def __init__(self) -> None:
57        super().__init__()
58        self.existing = torch.nn.Parameter(torch.ones([]))
59
60    def forward(self, x):
61        return self.existing * x
62
63
64global_obj = Obj()
65global_module = MyModule()
66global_var = torch.randn(3)
67global_num = 3.14
68global_list = []
69
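# Note: the module-level objects above are deliberately mutated by the
# side-effect tests below; each of those tests passes a `setup` callable to
# _assert_wrap_fallback that restores the global before both the eager and
# the compiled run.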
70
71def find_first_node(gm, func):
72    for node in gm.graph.nodes:
73        if node.target is func:
74            return node
75    return None
76
77
78def op_count(gm):
79    result = 0
80    for node in gm.graph.nodes:
81        if "call" in node.op:
82            result += 1
83    return result
84
85
# Checks that a dict matches a dict with "regex keys", i.e. a dict whose keys
# are regular expressions.
88def assert_dict_matches_regex(self, dct, dct_with_regex_keys):
89    regex_keys = dct_with_regex_keys.keys()
90    regex_key_to_actual_key = {}
91    for regex_key in regex_keys:
92        for key in dct:
93            if re.match(regex_key, key):
94                if regex_key in regex_key_to_actual_key:
95                    raise AssertionError(
96                        f"Single key regex mapped to multiple keys. Please improve your "
97                        f"regex. Got: regex='{regex_key}' "
98                        f"keys='{regex_key_to_actual_key[regex_key]}',"
99                        f"'{key}'"
100                    )
101                regex_key_to_actual_key[regex_key] = key
102    new_dct = {}
103    for regex_key in regex_keys:
104        if regex_key not in regex_key_to_actual_key:
105            raise AssertionError(
106                f"Got regex '{regex_key}' but could not match any key in dict with "
107                f"keys {dct.keys()}"
108            )
109        new_dct[regex_key_to_actual_key[regex_key]] = dct_with_regex_keys[regex_key]
110    self.assertEqual(dct, new_dct)
111
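# Illustrative usage sketch for assert_dict_matches_regex (the regex and the
# count below are hypothetical): each regex key must match exactly one actual
# key, and the remapped dict is then compared for equality, e.g.
#
#   assert_dict_matches_regex(self, dict(counters["graph_break"]), {r".*sin.*": 1})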
112
113def default_args_generator(seed_value):
114    flat_args, args_spec = pytree.tree_flatten(seed_value)
115    for i in range(3):
116        new_flat_arg = []
117        for val in flat_args:
118            if isinstance(val, torch.Tensor):
119                new_val = val + 0.1 * i
120            elif isinstance(val, int):
121                new_val = val + 1 * i
122            elif isinstance(val, float):
123                new_val = val + 0.1 * i
124            elif isinstance(val, enum.Enum):
125                new_val = val
126            else:
127                raise AssertionError("unexpected arg type")
128
129            new_flat_arg.append(new_val)
130        new_args = pytree.tree_unflatten(new_flat_arg, args_spec)
131        yield new_args
132
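# A small worked example of what default_args_generator yields (the seed is
# hypothetical): for a seed of (torch.zeros(2), 3, 0.5) the three yielded
# tuples are (t + 0.0, 3, 0.5), (t + 0.1, 4, 0.6) and (t + 0.2, 5, 0.7);
# tensors and floats are shifted by 0.1 * i, ints by i, and enum members are
# passed through unchanged.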
133
134class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
135    def _assert_wrap_fallback(self, func, args, setup=lambda: None):
136        counters.clear()
137        backend = EagerAndRecordGraphs()
138        cnt = CompileCounterWithBackend(backend)
139
140        setup()
141        expected = func(*args)
142        setup()
143        result = torch.compile(func, backend=cnt, fullgraph=False)(*args)
144        num_graph_breaks = len(counters["graph_break"].keys())
145        self.assertGreater(num_graph_breaks, 0)
146
147        for gm in backend.graphs:
148            for node in gm.graph.nodes:
149                self.assertFalse(node.target is wrap)
150
151        self.assertEqual(result, expected)
152
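    # In the fallback path exercised by _assert_wrap_fallback above, Dynamo
    # graph breaks on the unsupported side effect and runs `wrap` eagerly,
    # which is why no recorded graph may contain a node whose target is `wrap`.
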
153    def _test_wrap_simple(
154        self,
155        func,
156        args_generator,
157        expected_num_wrap_args,
158        expected_opcount=2,
159        return_graph=False,
160    ):
        # Given a `func` that contains a single call to `wrap`, we check that:
        # - there are no graph breaks
        # - eager and torch.compile produce the same result (correctness)
        # - other compilation metrics match, e.g. the number of ops in the
        #   Dynamo-captured graph and the number of args on the wrap node
        #
        # We run `func` once per set of args yielded by args_generator and check:
        # - correctness and the absence of graph breaks on every run
        # - the other compilation metrics only on the first run, since
        #   automatic_dynamic_shapes may compile a second, dynamic graph for
        #   later runs
173        graph = None
174        for i, args in enumerate(args_generator):
175            backend = EagerAndRecordGraphs()
176            cnt = CompileCounterWithBackend(backend)
177            expected = func(*args)
178            result = torch.compile(func, fullgraph=True, backend=cnt)(*args)
179            # check correctness and no graph breaks
180            self.assertEqual(result, expected)
181            self.assertEqual(cnt.frame_count, 1)
182            self.assertEqual(len(backend.graphs), 1)
183            # check other compilation metrics
184            if i == 0:
185                self.assertEqual(cnt.op_count, expected_opcount)
186                graph = backend.graphs[0]
187                wrap_node = find_first_node(graph, wrap)
188                self.assertEqual(len(wrap_node.args), expected_num_wrap_args)
        # If return_graph=True, we always return the (normalized) graph from the first run.
190        if return_graph:
191            return normalize_gm(graph.print_readable(print_output=False))
192
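    # Worked example for _test_wrap_simple's `expected_num_wrap_args` (taken
    # from the captured graphs checked below): for `wrap(lambda x: x + y, x)`
    # the captured node is `torch.ops.higher_order.wrap(wrap_body_0, l_x_, l_y_)`,
    # i.e. the body module plus one argument per lifted input, so
    # expected_num_wrap_args is 3.
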
193    def test_error_message_sane(self):
194        foo = []
195
196        def inner(x):
197            foo.append(x)
198            return x.clone()
199
200        @torch.compile(backend="eager", fullgraph=True)
201        def f(x):
202            return wrap(inner, x)
203
204        x = torch.randn(3)
205        with self.assertRaisesRegex(
206            torch._dynamo.exc.Unsupported,
207            r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)",
208        ):
209            f(x)
210
211    def test_no_freevars(self):
212        def f(x):
213            return wrap(lambda x: torch.sin(x), x)
214
215        x = torch.randn(3)
216        self._test_wrap_simple(f, default_args_generator((x,)), 2)
217
218    def test_enum_arg(self):
219        class SomeEnum(enum.Enum):
220            A = 0
221            B = 1
222
223        def g(x, val):
224            if val == SomeEnum.A:
225                return torch.sin(x)
226            return torch.cos(x)
227
228        def f(x, val):
229            return wrap(g, x, val)
230
231        x = torch.randn(3)
232        self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), 2)
233
234    def test_return_captured_var(self):
235        freevar = torch.randn(3)
236
237        def test(x):
238            return freevar
239
240        def fn(x):
241            return wrap(test, x)
242
243        x = torch.randn(3)
244
        # Since `x` is unused, it is not lifted to be an input.
247        self._test_wrap_simple(fn, default_args_generator((x,)), 2)
248
249    def test_return_captured_vars(self):
250        freevar1 = torch.randn(3)
251        freevar2 = torch.randn(3)
252
253        def test(x):
254            return freevar1, freevar2, freevar1
255
256        def fn(x):
257            return wrap(test, x)
258
259        x = torch.randn(3)
260
        # Since `x` is unused, it is not lifted to be an input.
263        self._test_wrap_simple(fn, default_args_generator((x,)), 3, 4)
264
265    def test_return_captured_var_used_multiple_times(self):
266        freevar = torch.randn(3)
267
268        def test(x):
269            y = x + freevar
270            return y, freevar
271
272        def fn(x):
273            return wrap(test, x)
274
275        x = torch.randn(3)
276        self._test_wrap_simple(fn, default_args_generator((x,)), 3, 3)
277
278    def test_capture_untracked_global(self):
279        def f(x):
280            return wrap(lambda x: x + global_var, x)
281
282        x = torch.randn(3)
283        self._test_wrap_simple(f, default_args_generator((x,)), 3)
284
285    def test_symint_input(self):
286        def f(x):
287            i = x.size(0)
288            return wrap(lambda x, i: x.view(i), x, i)
289
290        x = torch.randn(3, 1)
291        self._test_wrap_simple(
292            f,
293            default_args_generator((x,)),
294            ifdynstaticdefault(2, 3),
295            expected_opcount=2,
296        )
297
298    def test_wrap_pytree_args_nested(self):
299        def f(x, y, z):
300            def fn(d):
301                return d["x"].sin() + d["y"][0].cos() - d["y"][1][2].sin()
302
303            return wrap(fn, d)
304
305        x = torch.tensor(1.5)
306        y = torch.tensor(2.0)
307        z = torch.tensor(3.0)
308        d = {"x": x, "y": (y, [x, y, z])}
309
310        def my_args_generator(t):
311            yield t
312            yield t[0] + 0.1, t[1], t[2]
313            yield t[0], t[1] + 0.1, t[2]
314
315        actual_graph = self._test_wrap_simple(
316            f,
317            my_args_generator((x, y, z)),
318            4,
319            return_graph=True,
320        )
321        self.assertExpectedInline(
322            actual_graph,
323            """\
324class GraphModule(torch.nn.Module):
325    def forward(self, L_d_x_: "f32[]", L_d_y_0_: "f32[]", L_d_y_1_2_: "f32[]"):
326        l_d_x_ = L_d_x_
327        l_d_y_0_ = L_d_y_0_
328        l_d_y_1_2_ = L_d_y_1_2_
329
330        wrap_body_0 = self.wrap_body_0
331        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_);  wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
332        getitem: "f32[]" = wrap[0];  wrap = None
333        return (getitem,)
334
335    class wrap_body_0(torch.nn.Module):
336        def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"):
337            sin: "f32[]" = l_d_x_.sin();  l_d_x_ = None
338            cos: "f32[]" = l_d_y_0_.cos();  l_d_y_0_ = None
339            add: "f32[]" = sin + cos;  sin = cos = None
340            sin_1: "f32[]" = l_d_y_1_2_.sin();  l_d_y_1_2_ = None
341            sub: "f32[]" = add - sin_1;  add = sin_1 = None
342            return (sub,)
343""",  # NOQA: B950
344        )
345
346    def test_wrap_pytree_args_with_symint_constant(self):
347        def f(x, y):
348            i = x.size(0)
349            return wrap(lambda t: t[0].view(t[2]) + t[1], (x, y, i))
350
351        x = torch.randn(3, 1)
352        y = 0.5
353        actual_graph = self._test_wrap_simple(
354            f,
355            default_args_generator((x, y)),
356            ifdynstaticdefault(2, 3),
357            expected_opcount=2,
358            return_graph=True,
359        )
360        if torch._dynamo.config.assume_static_by_default:
361            self.assertExpectedInline(
362                actual_graph,
363                """\
364class GraphModule(torch.nn.Module):
365    def forward(self, L_x_: "f32[3, 1]"):
366        l_x_ = L_x_
367
368        wrap_body_0 = self.wrap_body_0
369        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
370        getitem: "f32[3]" = wrap[0];  wrap = None
371        return (getitem,)
372
373    class wrap_body_0(torch.nn.Module):
374        def forward(self, l_x_: "f32[3, 1]"):
375            view: "f32[3]" = l_x_.view(3);  l_x_ = None
376            add: "f32[3]" = view + 0.5;  view = None
377            return (add,)
378""",
379            )
380        else:
381            self.assertExpectedInline(
382                actual_graph,
383                """\
384class GraphModule(torch.nn.Module):
385    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
386        l_x_ = L_x_
387
388        wrap_body_0 = self.wrap_body_0
389        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, s0);  wrap_body_0 = l_x_ = s0 = None
390        getitem: "f32[s0]" = wrap[0];  wrap = None
391        return (getitem,)
392
393    class wrap_body_0(torch.nn.Module):
394        def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"):
395            view: "f32[s0]" = l_x_.view(size);  l_x_ = size = None
396            add: "f32[s0]" = view + 0.5;  view = None
397            return (add,)
398""",
399            )
400
401    def test_wrap_pytree_kwargs(self):
402        def f(x, y, z):
403            def fn(*, x, y, z):
404                z1, z2 = z
405                return (x * 2) + y + z1
406
407            return wrap(fn, x=x, y=y, z=z)
408
409        x = torch.randn(3)
410        y = torch.randn(3, 3)
411
412        def my_args_generator(t):
413            yield t
414            x1 = t[0] + 0.1
415            y1 = t[1] + 0.1
416            yield (x1, y1, (x1, y1))
417            x2 = t[0] + 0.2
            y2 = t[1] + 0.2
419            yield (x2, y2, (x2, y2))
420
421        self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), 3)
422
423    def test_wrap_pytree_args_not_const_symint_tensor(self):
424        class MyClass:
425            def __init__(self, x):
426                self.val = x
427
428        def f(x, y):
429            return wrap(lambda z: z[0].sin() * z[1].val.cos(), (x, y))
430
431        x = torch.tensor(1.2)
432        y = MyClass(torch.tensor(3.4))
433        self._test_wrap_simple(f, [(x, y)], 3)
434
435    def test_capture_constants(self):
436        x = torch.randn(3, 3)
437        y = 4.0
438
439        def fn(x, y, z):
440            if z:
441                return x + y
442            return x * y
443
444        def f(x, y, z):
445            return wrap(fn, x, y, z)
446
447        args = (x, 4.0, None)
448        opt_f = torch.compile(f, fullgraph=True, backend=CompileCounter())
449        expected = f(*args)
450        result = opt_f(*args)
451        self.assertEqual(result, expected)
452
453        # Ensure that we recompile here
454        args = (x, 5.0, None)
455        expected = f(*args)
456        result = opt_f(*args)
457        self.assertEqual(result, expected)
458
459    def test_capture_untracked_global_nested(self):
460        backend = EagerAndRecordGraphs()
461        cnt = CompileCounterWithBackend(backend)
462
463        @torch.compile(backend=cnt, fullgraph=True)
464        def f(x):
465            return wrap(lambda x: wrap(lambda x: x + global_var, x), x)
466
467        x = torch.randn(3)
468        result = f(x)
469
470        self.assertEqual(result, x + global_var)
471        self.assertEqual(cnt.frame_count, 1)
472        self.assertEqual(cnt.op_count, 2)
473
474        self.assertEqual(len(backend.graphs), 1)
475        wrap_node = find_first_node(backend.graphs[0], wrap)
        self.assertEqual(len(wrap_node.args), 3)
477
478        body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
479        self.assertEqual(op_count(body_function), 2)
480        inner_wrap_node = find_first_node(body_function, wrap)
        self.assertEqual(len(inner_wrap_node.args), 3)
482
483    def test_capture_untracked_nonlocal(self):
484        x = torch.randn(3, 3)
485        y = torch.randn(3, 3)
486
487        def f(x, y):
488            def g(x):
489                return wrap(lambda x: x + y, x)
490
491            self._test_wrap_simple(g, default_args_generator((x,)), 3)
492            return g(x)
493
494        f(x, y)
495
496    def test_capture_tracked(self):
497        x = torch.randn(3, 3)
498        y = torch.randn(3, 3)
499
500        def f(x, y):
501            return wrap(lambda x: x + y, x)
502
503        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
504
505    def test_capture_tracked_nested(self):
506        x = torch.randn(3, 3)
507        y = torch.randn(3, 3)
508
509        def f(x, y):
510            return wrap(lambda x: wrap(lambda x: x + y, x), x)
511
512        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
513
514    def test_inlined_functions(self):
515        def g(x, y):
516            return x + y
517
518        def f(x, y):
519            return wrap(lambda x: g(x, y), x)
520
521        x = torch.randn(3, 3)
522        y = torch.randn(3, 3)
523        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
524
525    def test_same_freevar_twice(self):
526        free = torch.randn(3)
527
528        def g(x):
529            y = free.sin()
530            z = free.cos()
531            return y, z
532
533        def f(x):
534            return wrap(g, x)
535
536        x = torch.randn(3)
537
        # Since `x` is unused, it is not lifted to be an input.
540        self._test_wrap_simple(f, default_args_generator((x,)), 2, 3)
541
542    def test_register_subclass(self):
543        from torch._higher_order_ops.cond import cond_op
544        from torch.testing._internal.two_tensor import TwoTensor
545
546        a = torch.tensor([1.0, 0.0, 1.0])
547        b = torch.randn(3)
548        t = TwoTensor(a, b)
549        with self.assertRaisesRegex(
550            NotImplementedError,
551            "no rule registered for HOP cond and subclass .*TwoTensor'>",
552        ):
553            res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
554
555        called = 0
556
557        # Using cond.py_impl
558        @cond_op.py_impl(TwoTensor)
559        def _(pred, true_fn, false_fn, operands):
560            nonlocal called
561            called += 1
562            assert len(operands) == 1
563            a = cond_op(pred, true_fn, false_fn, (operands[0].a,))
564            b = cond_op(pred, true_fn, false_fn, (operands[0].b,))
565            return TwoTensor(a, b)
566
567        res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
568        self.assertEqual(res.a, torch.sin(a))
569        self.assertEqual(res.b, torch.sin(b))
570        self.assertEqual(called, 1)
571
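    # The py_impl registration in test_register_subclass above keys the rule
    # on the subclass type: once a rule for TwoTensor exists, cond_op
    # dispatches to it whenever an operand is a TwoTensor, and the rule here
    # simply recurses into the .a/.b components.
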
572    def test_register_mode(self):
573        from torch._higher_order_ops.cond import cond_op
574
575        torch_dispatch_called = 0
576
577        class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
578            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
579                nonlocal torch_dispatch_called
580                torch_dispatch_called += 1
581                return func(*args, **kwargs)
582
583        a = torch.tensor([1.0, 0.1, 1.0])
584        pred = a.sum() > 0
585        with self.assertRaisesRegex(
586            NotImplementedError,
587            "no rule registered for HOP cond and mode .*MyMode",
588        ):
589            with MyMode():
590                res = cond_op(pred, torch.sin, torch.cos, (a,))
591
592        py_impl_called = 0
593
594        # Using cond.py_impl
595        @cond_op.py_impl(MyMode)
596        def _(mode, pred, true_fn, false_fn, operands):
597            nonlocal py_impl_called
598            py_impl_called += 1
599            return cond_op(pred, true_fn, false_fn, operands)
600
601        a = torch.tensor([1.0, 0.1, 1.0])
602        pred = a.sum() > 0
603        with MyMode():
604            res = cond_op(pred, torch.sin, torch.cos, (a,))
605        self.assertEqual(res, a.sin())
606
607    def test_capture_value_created_in_subgraph(self):
608        backend = EagerAndRecordGraphs()
609        cnt = CompileCounterWithBackend(backend)
610
611        x = torch.randn(3, 3)
612        y = torch.randn(3, 3)
613
614        def inner(x, y):
615            z = x + y
616            return wrap(lambda x: wrap(lambda x: x + z, x), x)
617
618        @torch.compile(backend=cnt, fullgraph=True)
619        def f(x, y):
620            return wrap(inner, x, y)
621
622        result = f(x, y)
623
624        self.assertEqual(result, x + y + x)
625        self.assertEqual(cnt.frame_count, 1)
626        self.assertEqual(cnt.op_count, 2)
627        self.assertEqual(len(backend.graphs), 1)
628
629        # No changes to args of outer wrap
630        gm = backend.graphs[0]
631        wrap_node = find_first_node(gm, wrap)
        self.assertEqual(len(wrap_node.args), 3)
633
634        # z was lifted to arg of inner wrap
635        body_function = getattr(gm, wrap_node.args[0].name)
636        # addition + wrap + getitem
637        self.assertEqual(op_count(body_function), 3)
638        inner_wrap_node = find_first_node(body_function, wrap)
        self.assertEqual(len(inner_wrap_node.args), 3)
640
641        # Innermost body function: z was also lifted to arg
642        body_function = getattr(body_function, inner_wrap_node.args[0].name)
643        self.assertEqual(op_count(body_function), 2)
644        inner_wrap_node = find_first_node(body_function, wrap)
        self.assertEqual(len(inner_wrap_node.args), 3)
646
647    def test_side_effect_set_new_attr_global_obj(self):
648        def setup():
649            global global_obj
650            global_obj = Obj()
651
652        def f(x):
653            def h(x):
654                def g(x):
655                    global_obj.foo = x + 1
656                    return x.clone()
657
658                y = wrap(g, x)
659                return y + global_obj.foo
660
661            return h(x)
662
663        x = torch.zeros([])
664        self._assert_wrap_fallback(f, (x,), setup=setup)
665
666    def test_side_effect_set_existing_attr_global_obj(self):
667        def setup():
668            global global_obj
669            global_obj = Obj()
670            global_obj.foo = nn.Parameter(torch.tensor(4.0))
671
672        def f(x):
673            def h(x):
674                def g(x):
675                    global_obj.foo = x + 1
676                    return x.clone()
677
678                y = wrap(g, x)
679                return y + global_obj.foo
680
681            return h(x)
682
683        x = torch.zeros([])
684        self._assert_wrap_fallback(f, (x,), setup=setup)
685
686    def test_side_effect_del_existing_attr_global_obj(self):
687        def setup():
688            global global_obj
689            global_obj = Obj()
690            global_obj.foo = torch.tensor(4.0)
691
692        def f(x):
693            def h(x):
694                def g(x):
695                    del global_obj.foo
696                    return x.clone()
697
698                y = wrap(g, x)
699                return y
700
701            return h(x)
702
703        x = torch.zeros([])
704        self._assert_wrap_fallback(f, (x,), setup=setup)
705
706    def test_side_effect_set_new_attr_global_module(self):
707        def setup():
708            global global_module
709            global_module = MyModule()
710
711        def h(x):
712            def g(x):
713                global_module.foo = nn.Parameter(x + 1)
714                return x.clone()
715
716            y = wrap(g, x)
717            return y + global_module.foo
718
719        x = torch.zeros([])
720        self._assert_wrap_fallback(h, (x,), setup=setup)
721
722    def test_side_effect_set_existing_attr_global_module(self):
723        def setup():
724            global global_module
725            global_module = MyModule()
726
727        def h(x):
728            def g(x):
729                global_module.existing = nn.Parameter(torch.tensor(4.0))
730                return global_module(x)
731
732            y = wrap(g, x)
733            return y
734
735        x = torch.zeros([])
736        self._assert_wrap_fallback(h, (x,), setup=setup)
737
738    def test_side_effect_del_existing_attr_global_module(self):
739        def setup():
740            global global_module
741            global_module = MyModule()
742
743        def h(x):
744            def g(x):
745                del global_module.existing
746                return x.clone()
747
748            y = wrap(g, x)
749            return y
750
751        x = torch.zeros([])
752        self._assert_wrap_fallback(h, (x,), setup=setup)
753
754    def test_side_effect_mutate_global_num(self):
755        def setup():
756            global global_num
757            global_num = 3.14
758
759        def f(x):
760            def g(x):
761                global global_num
762                global_num = global_num + 1
763                return x + global_num
764
765            y = wrap(g, x)
766            return y + global_num
767
768        x = torch.zeros([])
769        self._assert_wrap_fallback(f, (x,), setup=setup)
770
771    def test_side_effect_mutate_global_num_builtin(self):
772        def setup():
773            global global_num
774            global_num = 3.14
775
776        def f(x):
777            def g(x):
778                global global_num
779                global_num += 1
780                return x + global_num
781
782            y = wrap(g, x)
783            return y + global_num
784
785        x = torch.zeros([])
786        self._assert_wrap_fallback(f, (x,), setup=setup)
787
788    def test_side_effect_mutate_global_tensor(self):
789        def setup():
790            global global_var
791            global_var = torch.ones(3)
792
793        def f(x):
794            def g(x):
795                global global_var
796                global_var = global_var + 1
797                return x + global_var
798
799            y = wrap(g, x)
800            return y + global_var
801
802        x = torch.zeros([])
803        self._assert_wrap_fallback(f, (x,), setup=setup)
804
805    def test_side_effect_mutate_global_tensor_builtin(self):
806        def setup():
807            global global_var
808            global_var = torch.ones(3)
809
810        def f(x):
811            def g(x):
812                global global_var
813                global_var += 1
814                return x + global_var
815
816            y = wrap(g, x)
817            return y + global_var
818
819        x = torch.zeros([])
820        self._assert_wrap_fallback(f, (x,), setup=setup)
821
822    def test_side_effect_mutate_global_list(self):
823        def setup():
824            global global_list
825            global_list = []
826
827        def f(x):
828            def g(x):
829                val = x + 1
830                global_list.append(val)
831                return global_list[-1]
832
833            y = wrap(g, x)
834            z = y + global_list[-1]
835            return z
836
837        x = torch.zeros([])
838        self._assert_wrap_fallback(f, (x,), setup=setup)
839
840    def test_side_effect_mutate_nonlocal_num(self):
841        def f(x):
842            def h(x):
843                val = 1
844
845                def g(x):
846                    nonlocal val
847                    val = val + 1
848                    return x + val
849
850                y = wrap(g, x)
851                z = y + val
852                return z
853
854            return h(x)
855
856        x = torch.zeros([])
857        self._assert_wrap_fallback(f, (x,))
858
859    def test_side_effect_set_new_attr_nonlocal_obj(self):
860        def f(x):
861            def h(x):
862                obj = Obj()
863
864                def g(x):
865                    obj.val = x.dim()
866                    return x.clone()
867
868                y = wrap(g, x)
869                z = y + obj.val
870                return z
871
872            return h(x)
873
874        x = torch.zeros([])
875        self._assert_wrap_fallback(f, (x,))
876
877    def test_side_effect_set_existing_attr_nonlocal_obj(self):
878        def f(x):
879            def h(x):
880                obj = Obj()
881                obj.val = 3
882
883                def g(x):
884                    obj.val = x.dim()
885                    return x.clone()
886
887                y = wrap(g, x)
888                z = y + obj.val
889                return z
890
891            return h(x)
892
893        x = torch.zeros([])
894        self._assert_wrap_fallback(f, (x,))
895
896    def test_side_effect_del_existing_attr_nonlocal_obj(self):
897        def f(x):
898            def h(x):
899                obj = Obj()
900                obj.val = 3
901
902                def g(x):
903                    del obj.val
904                    return x.clone()
905
906                y = wrap(g, x)
907                return y
908
909            return h(x)
910
911        x = torch.zeros([])
912        self._assert_wrap_fallback(f, (x,))
913
914    def test_side_effect_set_new_attr_nonlocal_module(self):
915        def h(x):
916            obj = MyModule()
917
918            def g(x):
919                obj.val = x.dim()
920                return x.clone()
921
922            y = wrap(g, x)
923            z = y + obj.val
924            return z
925
926        x = torch.zeros([])
927        self._assert_wrap_fallback(h, (x,))
928
929    def test_side_effect_set_existing_attr_nonlocal_module(self):
930        def h(x):
931            obj = MyModule()
932
933            def g(x):
934                obj.existing = nn.Parameter(torch.tensor(3.14))
935                return obj(x)
936
937            y = wrap(g, x)
938            return y
939
940        x = torch.zeros([])
941        self._assert_wrap_fallback(h, (x,))
942
943    def test_side_effect_del_existing_attr_nonlocal_module(self):
944        def h(x):
945            obj = MyModule()
946
947            def g(x):
948                del obj.existing
949                return x.clone()
950
951            y = wrap(g, x)
952            return y
953
954        x = torch.zeros([])
955        self._assert_wrap_fallback(h, (x,))
956
957    def test_side_effect_mutate_nonlocal_tensor(self):
958        def f(x):
959            def h(x):
960                val = torch.tensor(1.0)
961
962                def g(x):
963                    nonlocal val
964                    val = val + 1
965                    return x + val
966
967                y = wrap(g, x)
968                z = y + val
969                return z
970
971            return h(x)
972
973        x = torch.zeros([])
974        self._assert_wrap_fallback(f, (x,))
975
976    def test_side_effect_mutate_nonlocal_num_builtin(self):
977        def f(x):
978            def h(x):
979                val = 1
980
981                def g(x):
982                    nonlocal val
983                    val += 1
984                    return x + val
985
986                y = wrap(g, x)
987                z = y + val
988                return z
989
990            return h(x)
991
992        x = torch.zeros([])
993        self._assert_wrap_fallback(f, (x,))
994
995    def test_side_effect_mutate_nonlocal_tensor_builtin(self):
996        def f(x):
997            def h(x):
998                val = torch.tensor(1.0)
999
1000                def g(x):
1001                    nonlocal val
1002                    val += 1
1003                    return x + val
1004
1005                y = wrap(g, x)
1006                z = y + val
1007                return z
1008
1009            return h(x)
1010
1011        x = torch.zeros([])
1012        self._assert_wrap_fallback(f, (x,))
1013
1014    def test_side_effect_nonlocal_list_append_graph_break(self):
1015        def g(x):
1016            y = []
1017
1018            def f(k):
1019                m = k + 1
1020                y.append(m)
1021                return k
1022
1023            wrap(f, x)
1024            return y[0]
1025
1026        x = torch.randn(3, 3)
1027        self._assert_wrap_fallback(g, (x,))
1028
1029    def test_side_effect_nested_nonlocal_list_append_graph_break(self):
1030        def g(x):
1031            def h(x):
1032                y = []
1033
1034                def f(k):
1035                    m = k + 1
1036                    y.append(m)
1037                    return k
1038
1039                wrap(f, x)
1040                return y[0]
1041
1042            return h(x)
1043
1044        x = torch.randn(3, 3)
1045        self._assert_wrap_fallback(g, (x,))
1046
1047    def test_side_effect_local_list_append_no_graph_break(self):
1048        def g(x):
1049            def f(k):
1050                y = []
1051                y.append(k + 1)
1052                return y[0]
1053
1054            return wrap(f, x)
1055
1056        x = torch.randn(3, 3)
1057        self._test_wrap_simple(g, default_args_generator((x,)), 2)
1058
1059    def test_wrap_kwarg(self):
1060        def f(x, y):
1061            return wrap(lambda x, y: x + y, x, y=y)
1062
1063        x = torch.randn(3)
1064        y = torch.randn(3, 3)
1065        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
1066
1067    def test_wrap_kwarg_int(self):
1068        def f(x, y):
1069            return wrap(lambda x, y: x + y, x, y=y)
1070
1071        x = torch.randn(3)
1072        y = 8
1073
1074        self._test_wrap_simple(
1075            f, default_args_generator((x, y)), ifdynstaticdefault(2, 3)
1076        )
1077
1078    def test_wrap_all_kwarg(self):
1079        def f(y, x):
1080            return wrap(lambda x, y: (x * 2) + y, x=x, y=y)
1081
1082        x = torch.randn(3)
1083        y = torch.randn(3, 3)
1084
1085        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
1086
1087    def test_wrap_kwarg_only(self):
1088        def f(x, y):
1089            def fn(*, x, y):
1090                return (x * 2) + y
1091
1092            return wrap(fn, x=x, y=y)
1093
1094        x = torch.randn(3)
1095        y = torch.randn(3, 3)
1096
1097        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
1098
1099    def test_wrap_kwarg_default(self):
1100        def f(x, y):
1101            def fn(*, x, y, z=8):
1102                return (x * 2) + y + z
1103
1104            return wrap(fn, x=x, y=y)
1105
1106        x = torch.randn(3)
1107        y = torch.randn(3, 3)
1108
1109        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
1110
1111    def test_wrap_kwarg_default_if_branch(self):
1112        def f(x, y):
1113            def fn(*, x, y, z=None):
1114                if z is None:
1115                    return (x * 2) + y
1116                else:
1117                    return 2 * x
1118
1119            return wrap(fn, x=x, y=y)
1120
1121        x = torch.randn(3)
1122        y = torch.randn(3, 3)
1123
1124        self._test_wrap_simple(f, default_args_generator((x, y)), 3)
1125
1126    def test_wrap_kwarg_recompile(self):
1127        def f(x, y, z=None):
1128            def fn(*, x, y, z=None):
1129                if z is None:
1130                    return (x * 2) + y
1131                else:
1132                    return 2 * x
1133
1134            return wrap(fn, x=x, y=y, z=z)
1135
1136        x = torch.randn(3)
1137        y = torch.randn(3, 3)
1138
1139        counters.clear()
1140        opt = torch.compile(f, backend="eager", fullgraph=True)
1141        opt(x, y)
1142        self.assertEqual(counters["stats"]["calls_captured"], 2)
1143
1144        # verify that we `don't` recompile
1145        opt(x, y)
1146        self.assertEqual(counters["stats"]["calls_captured"], 2)
1147
1148        output = opt(x, y, 8)
1149        self.assertEqual(counters["stats"]["calls_captured"], 4)
1150        self.assertEqual(output, 2 * x)
1151
1152    def test_wrap_kwarg_default_else_branch(self):
1153        def f(x, y, z):
1154            def fn(*, x, y, z=None):
1155                if z is None:
1156                    return (x * 2) + y
1157                else:
1158                    return 2 * x
1159
1160            return wrap(fn, x=x, y=y, z=z)
1161
1162        x = torch.randn(3)
1163        y = torch.randn(3, 3)
1164
1165        self._test_wrap_simple(f, default_args_generator((x, y, 8)), 2)
1166
1167    def test_map_subgraph_name_is_valid(self):
1168        backend = EagerAndRecordGraphs()
1169        cnt = CompileCounterWithBackend(backend)
1170
1171        xs = torch.randn(2, 3, 3)
1172        y = torch.randn(3)
1173
1174        def map_f(xs, y):
1175            def inner(x, y):
1176                def inner2(x, y):
1177                    return x + y
1178
1179                return control_flow.map(inner2, x, y)
1180
1181            return control_flow.map(inner, xs, y)
1182
1183        graphs = self._check_map_graph_and_extract(map_f, (xs, y))
1184        if graphs:
1185            graph, body_graph = graphs
1186            self.assertExpectedInline(
1187                graph,
1188                """\
1189def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
1190    l_xs_ = L_xs_
1191    l_y_ = L_y_
1192    map_body_1 = self.map_body_1
1193    map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]);  map_body_1 = l_xs_ = l_y_ = None
1194    getitem_1 = map_impl[0];  map_impl = None
1195    return (getitem_1,)""",
1196            )
1197            self.assertExpectedInline(
1198                body_graph,
1199                """\
1200def forward(self, child, l_y_):
1201    child_1 = child[0];  child_1 = None
1202    map_body_0 = self.map_body_0
1203    map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]);  map_body_0 = child = l_y_ = None
1204    getitem_1 = map_impl[0];  map_impl = None
1205    return (getitem_1,)""",
1206            )
1207
1208    def test_map_multi_return(self):
1209        cnt = CompileCounter()
1210
1211        def f(x):
1212            return control_flow.map(lambda x: (x.sin(), x.sin()), x)
1213
1214        x = torch.randn(3)
1215        graphs = self._check_map_graph_and_extract(f, (x,))
1216        if graphs:
1217            graph, body_graph = graphs
1218            self.assertExpectedInline(
1219                graph,
1220                """\
1221def forward(self, L_x_ : torch.Tensor):
1222    l_x_ = L_x_
1223    map_body_0 = self.map_body_0
1224    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []);  map_body_0 = l_x_ = None
1225    getitem_1 = map_impl[0]
1226    getitem_2 = map_impl[1];  map_impl = None
1227    return (getitem_1, getitem_2)""",
1228            )
1229            self.assertExpectedInline(
1230                body_graph,
1231                """\
1232def forward(self, child):
1233    child_1 = child.sin()
1234    child_2 = child.sin();  child = None
1235    return (child_1, child_2)""",
1236            )
1237
1238    def test_map_pytree_return(self):
1239        cnt = CompileCounter()
1240
1241        def _construct_pytree(a):
1242            return (a, [[[a]]], a, (a, (a,), a), {"a": a})
1243
1244        def f(x):
1245            def inner_f(xs):
1246                return _construct_pytree(xs)
1247
1248            return control_flow.map(inner_f, x)
1249
1250        x = torch.randn(3)
1251        graphs = self._check_map_graph_and_extract(f, (x,))
1252        if graphs:
1253            graph, body_graph = graphs
1254            self.assertExpectedInline(
1255                graph,
1256                """\
1257def forward(self, L_x_ : torch.Tensor):
1258    l_x_ = L_x_
1259    map_body_0 = self.map_body_0
1260    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []);  map_body_0 = l_x_ = None
1261    getitem_1 = map_impl[0]
1262    getitem_2 = map_impl[1]
1263    getitem_3 = map_impl[2]
1264    getitem_4 = map_impl[3]
1265    getitem_5 = map_impl[4]
1266    getitem_6 = map_impl[5]
1267    getitem_7 = map_impl[6];  map_impl = None
1268    return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""",
1269            )
1270            self.assertExpectedInline(
1271                body_graph,
1272                """\
1273def forward(self, child):
1274    return (child, child, child, child, child, child, child)""",
1275            )
1276
1277    def test_map_kwargs(self):
1278        cnt = CompileCounter()
1279
1280        @torch.compile(backend=cnt)
1281        def f(x):
1282            return control_flow.map(lambda x: x.sin(), x=x)
1283
1284        x = torch.randn(3)
1285        self.assertRaises(TypeError, lambda: f(x))
1286        self.assertEqual(cnt.frame_count, 0)
1287
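    # The TypeError in test_map_kwargs above comes from control_flow.map not
    # accepting its operands as keyword arguments, so nothing is compiled and
    # frame_count stays at 0.
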
1288    def test_map_symint_input(self):
1289        backend = EagerAndRecordGraphs()
1290        cnt = CompileCounterWithBackend(backend)
1291
1292        def fn(x, y):
1293            def inner(x, y):
1294                return torch.sin(x + y)
1295
1296            return control_flow.map(inner, x, y.size(0))
1297
1298        x = torch.randn(3, 1)
1299        y = torch.randn(3, 1)
1300        graphs = self._check_map_graph_and_extract(fn, (x, y))
1301        if graphs:
1302            graph, body_graph = graphs
1303            self.assertExpectedInline(
1304                graph,
1305                """\
1306def forward(self, L_x_ : torch.Tensor):
1307    l_x_ = L_x_
1308    map_body_0 = self.map_body_0
1309    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]);  map_body_0 = l_x_ = None
1310    getitem_1 = map_impl[0];  map_impl = None
1311    return (getitem_1,)""",
1312            )
1313            self.assertExpectedInline(
1314                body_graph,
1315                """\
1316def forward(self, child, const_unused):
1317    add = child + 3;  child = None
1318    sin = torch.sin(add);  add = None
1319    return (sin,)""",
1320            )
1321
1322    def test_map_lowers_to_graph(self):
1323        backend = EagerAndRecordGraphs()
1324        cnt = CompileCounterWithBackend(backend)
1325
1326        def fn(x, y):
1327            def inner(x, y):
1328                return torch.sin(x + y)
1329
1330            return control_flow.map(inner, x, y.size(0))
1331
1332        x = torch.randn(3, 1)
1333        y = torch.randn(3, 1)
1334        graphs = self._check_map_graph_and_extract(fn, (x, y))
1335        if graphs:
1336            graph, body_graph = graphs
1337            self.assertExpectedInline(
1338                graph,
1339                """\
1340def forward(self, L_x_ : torch.Tensor):
1341    l_x_ = L_x_
1342    map_body_0 = self.map_body_0
1343    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]);  map_body_0 = l_x_ = None
1344    getitem_1 = map_impl[0];  map_impl = None
1345    return (getitem_1,)""",
1346            )
1347            self.assertExpectedInline(
1348                body_graph,
1349                """\
1350def forward(self, child, const_unused):
1351    add = child + 3;  child = None
1352    sin = torch.sin(add);  add = None
1353    return (sin,)""",
1354            )
1355
1356    def test_map_example_value_metadata_consistent_with_eager(self):
1357        from torch._higher_order_ops.map import map_dense
1358
1359        backend = EagerAndRecordGraphs()
1360
1361        def inner(x):
1362            return x.sin(), x.cos().T, x.sin().view(-1)
1363
1364        rand_44 = torch.randn(4, 4)
1365        inps = [
1366            torch.randn(3),
1367            torch.randn(3, 4),
1368            torch.randn(3, 4, 5, requires_grad=True),
1369            torch.randn(3, 4, 5, requires_grad=True).permute((2, 0, 1)),
1370            torch.randn(3, 4, 5, requires_grad=True).detach(),
1371            torch.randn(3, 4, 5, requires_grad=True).narrow(1, 1, 2),
1372            rand_44.T,
1373            rand_44[::2],
1374            rand_44[::2, ::2],
1375            rand_44[1::3, 1::3],
1376            rand_44[1::3, 1::2].T,
1377            rand_44.unsqueeze(1),
1378            rand_44.squeeze(0),
1379            rand_44.reshape(2, 8),
1380        ]
1381        for x in inps:
1382            compiled_ret = torch.compile(
1383                control_flow.map, backend=backend, fullgraph=True
1384            )(inner, x)
1385            eager_sin, eager_transpose, eager_view = map_dense(inner, (x,), ())
1386
1387            map_node = next(
1388                node
1389                for node in backend.graphs[0].graph.nodes
1390                if node.op == "call_function" and "map" in node.name
1391            )
1392
1393            fake_sin, fake_transpose, fake_view = map_node.meta["example_value"]
1394
1395            def _check_size_stride_contiguous(x, y):
1396                self.assertEqual(y.size(), x.size())
1397                self.assertEqual(y.stride(), x.stride())
1398                self.assertEqual(y.requires_grad, x.requires_grad)
1399                self.assertEqual(x.is_contiguous(), True)
1400                self.assertEqual(y.is_contiguous(), True)
1401
1402            _check_size_stride_contiguous(eager_sin, fake_sin)
1403            _check_size_stride_contiguous(eager_transpose, fake_transpose)
1404            _check_size_stride_contiguous(eager_view, fake_view)
1405
1406            torch._dynamo.reset()
1407            backend.graphs.clear()
1408
1409    def test_cond_subgraph_name_is_valid(self):
1410        backend = EagerAndRecordGraphs()
1411        cnt = CompileCounterWithBackend(backend)
1412
1413        pred = torch.tensor(True)
1414        pred2 = torch.tensor(False)
1415        xs = torch.randn(2, 3, 3)
1416        y = torch.randn(3, 3)
1417
1418        @torch.compile(backend=cnt, fullgraph=True)
1419        def cond_f(pred, pred2, x, y):
1420            def true_fn(pred2, x, y):
1421                return x + y
1422
1423            def false_fn(pred2, x, y):
1424                def true_fn2(x, y):
1425                    return x.sin() - y.cos()
1426
1427                def false_fn2(x, y):
1428                    return x.cos() - y.sin()
1429
1430                return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
1431
1432            return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
1433
1434        result = cond_f(pred, pred2, xs, y)
1435        self.assertEqual(result, xs + y)
1436
1437        cond_gm = backend.graphs[0]
1438        name_set = set()
1439        name_set.update(name for name, _ in cond_gm.named_modules())
1440        self.assertEqual(
1441            name_set,
1442            {
1443                "",
1444                "cond_true_1",
1445                "cond_false_1",
1446                "cond_false_1.cond_false_0",
1447                "cond_false_1.cond_true_0",
1448            },
1449        )
1450
1451    @torch._dynamo.config.patch(
1452        assume_static_by_default=True,
1453        dynamic_shapes=True,
1454    )
1455    def test_cond_graph_break_in_one_branch(self):
1456        backend = EagerAndRecordGraphs()
1457        cnt = CompileCounterWithBackend(backend)
1458
1459        class Foo(torch.nn.Module):
1460            def __init__(self) -> None:
1461                super().__init__()
1462                self.buffer = torch.nn.Buffer(torch.ones(6, 4))
1463
1464            def forward(self, x):
1465                def true_fn(x):
1466                    self.buffer += 1
1467                    return self.buffer.sum() + x.sum()
1468
1469                def false_fn(x):
1470                    return (x - 1).sum()
1471
1472                return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])
1473
1474        mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
1475        mod_for_eager = Foo()
1476
1477        with self.assertRaisesRegex(
1478            torch._dynamo.exc.UncapturedHigherOrderOpError,
1479            r"Cond doesn't work unless it is captured completely with torch.compile",
1480        ):
1481            mod_for_eager(torch.ones(6, 4))
1482
1483        with self.assertRaisesRegex(
1484            torch._dynamo.exc.UncapturedHigherOrderOpError,
1485            r"Cond doesn't work unless it is captured completely with torch.compile",
1486        ):
1487            mod_for_compile(torch.ones(3, 4))
1488
1489    def test_cond_free_variable_in_both_branches(self):
1490        backend = EagerAndRecordGraphs()
1491        cnt = CompileCounterWithBackend(backend)
1492
1493        z = torch.ones(4, 4)
1494
1495        class Foo(torch.nn.Module):
1496            def __init__(self) -> None:
1497                super().__init__()
1498                self.buffer = torch.nn.Buffer(torch.ones(6, 4))
1499
1500            def forward(self, x, y):
1501                def true_fn(x):
1502                    return x.sum() + self.buffer.sum() + z.sum()
1503
1504                def false_fn(x):
1505                    return x.sum() - z.sum() - self.buffer.sum()
1506
1507                return control_flow.cond(y, true_fn, false_fn, [x])
1508
1509        mod_for_compile = torch.compile(
1510            Foo(), backend=cnt, dynamic=True, fullgraph=True
1511        )
1512        mod_for_eager = Foo()
1513
1514        self.assertEqual(
1515            mod_for_compile(torch.tensor(True), torch.tensor(5)),
1516            mod_for_eager(torch.tensor(True), torch.tensor(5)),
1517        )
1518
1519        for node in backend.graphs[0].graph.nodes:
1520            if (
1521                node.op == "call_function"
1522                and node.target == torch.ops.higher_order.cond
1523            ):
1524                _, _, _, operands = node.args
1525                # Each branch takes 3 inputs (buffer, x, z)
1526                self.assertEqual(len(operands), 3)
1527            if node.op == "get_attr":
                if str(node.target) in ("cond_true_0", "cond_false_0"):
                    num_placeholders = len(
                        [
                            n
                            for n in getattr(
                                backend.graphs[0], str(node.target)
                            ).graph.nodes
                            if n.op == "placeholder"
                        ]
                    )
1538                    self.assertEqual(num_placeholders, 3)
1539
1540    def _check_cond_graph_and_extract(self, fn, args):
1541        backend = EagerAndRecordGraphs()
1542        cnt = CompileCounterWithBackend(backend)
1543        out = torch.compile(fn, backend=cnt, fullgraph=True)(*args)
1544        self.assertEqual(out, fn(*args))
1545        self.assertEqual(cnt.frame_count, 1)
1546        self.assertEqual(len(backend.graphs), 1)
1547
1548        # Dynamic shapes produce a slightly different graph.
1549        if check_dynamic_shape_capture():
1550            return
1551
1552        gm = backend.graphs[0]
1553        graph = gm.code.strip()
1554        true_graph = gm.cond_true_0.code.strip()
1555        false_graph = gm.cond_false_0.code.strip()
1556        return (graph, true_graph, false_graph)
1557
1558    def _check_map_graph_and_extract(self, fn, args):
1559        backend = EagerAndRecordGraphs()
1560        cnt = CompileCounterWithBackend(backend)
1561        out = torch.compile(fn, backend=cnt, fullgraph=True)(*args)
1562        self.assertEqual(out, fn(*args))
1563        self.assertEqual(cnt.frame_count, 1)
1564        self.assertEqual(len(backend.graphs), 1)
1565
1566        # Dynamic shapes produce a slightly different graph.
1567        if check_dynamic_shape_capture():
1568            return
1569
1570        gm = backend.graphs[0]
1571        graph = gm.code.strip()
1572        subgraphs = []
1573        for module_name in gm._modules.keys():
1574            subgraphs.append(getattr(gm, module_name).code.strip())
1575        return (graph, *subgraphs)
1576
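    # Note: both extractor helpers above return None under dynamic-shape
    # capture (the printed graphs differ in that mode), which is why callers
    # guard their assertExpectedInline checks with `if graphs:`.
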
1577    def test_cond_branches_no_arguments(self):
1578        def fn(x):
1579            def true_fn():
1580                return torch.sin(x)
1581
1582            def false_fn():
1583                return torch.cos(x)
1584
1585            return control_flow.cond(x.sum() > 0, true_fn, false_fn, ())
1586
1587        graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),))
1588        if graphs is not None:
1589            graph, true_graph, false_graph = graphs
1590            self.assertExpectedInline(
1591                graph,
1592                """\
1593def forward(self, L_x_ : torch.Tensor):
1594    l_x_ = L_x_
1595    sum_1 = l_x_.sum()
1596    gt = sum_1 > 0;  sum_1 = None
1597    cond_true_0 = self.cond_true_0
1598    cond_false_0 = self.cond_false_0
1599    cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, [l_x_]);  gt = cond_true_0 = cond_false_0 = l_x_ = None
1600    getitem = cond[0];  cond = None
1601    return (getitem,)""",
1602            )
1603            self.assertExpectedInline(
1604                true_graph,
1605                """\
1606def forward(self, l_x_):
1607    l_x__1 = l_x_
1608    sin = torch.sin(l_x__1);  l_x__1 = None
1609    return (sin,)""",
1610            )
1611            self.assertExpectedInline(
1612                false_graph,
1613                """\
1614def forward(self, l_x_):
1615    l_x__1 = l_x_
1616    cos = torch.cos(l_x__1);  l_x__1 = None
1617    return (cos,)""",
1618            )
1619
1620    def test_cond_branches_no_arguments_no_closure(self):
1621        def fn(x):
1622            def true_fn():
1623                return torch.ones(3, 4)
1624
1625            def false_fn():
1626                return torch.ones(3, 4).sin()
1627
1628            return control_flow.cond(x.sum() > 0, true_fn, false_fn, ())
1629
1631        graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),))
1632        if graphs is not None:
1633            graph, true_graph, false_graph = graphs
1634            self.assertExpectedInline(
1635                graph,
1636                """\
1637def forward(self, L_x_ : torch.Tensor):
1638    l_x_ = L_x_
1639    sum_1 = l_x_.sum();  l_x_ = None
1640    gt = sum_1 > 0;  sum_1 = None
1641    cond_true_0 = self.cond_true_0
1642    cond_false_0 = self.cond_false_0
1643    cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, []);  gt = cond_true_0 = cond_false_0 = None
1644    getitem = cond[0];  cond = None
1645    return (getitem,)""",
1646            )
1647            self.assertExpectedInline(
1648                true_graph,
1649                """\
1650def forward(self):
1651    ones = torch.ones(3, 4)
1652    return (ones,)""",
1653            )
1654            self.assertExpectedInline(
1655                false_graph,
1656                """\
1657def forward(self):
1658    ones = torch.ones(3, 4)
1659    sin = ones.sin();  ones = None
1660    return (sin,)""",
1661            )
1662
1663    def test_cond_side_effect_in_one_branches(self):
1664        backend = EagerAndRecordGraphs()
1665        cnt = CompileCounterWithBackend(backend)
1666
1667        z = [torch.ones(4, 4)]
1668
1669        class Foo(torch.nn.Module):
1670            def __init__(self) -> None:
1671                super().__init__()
1672
1673            def forward(self, y, x):
1674                def true_fn(x):
1675                    z.append(x)
1676                    z.append(x)
1677                    z.pop()
1678                    return x.sum() + z[-1].sum()
1679
1680                def false_fn(x):
1681                    return x.sum() - z[0].sum()
1682
1683                return control_flow.cond(y, true_fn, false_fn, [x])
1684
1685        mod_for_eager = Foo()
1686        mod_for_compile = torch.compile(
1687            Foo(), backend=cnt, dynamic=True, fullgraph=False
1688        )
1689        with self.assertRaisesRegex(
1690            torch._dynamo.exc.UncapturedHigherOrderOpError,
1691            r"Cond doesn't work unless it is captured completely with torch.compile",
1692        ):
1693            mod_for_eager(torch.tensor(True), torch.tensor(5))
1694
1695        with self.assertRaisesRegex(
1696            torch._dynamo.exc.UncapturedHigherOrderOpError,
1697            r"Cond doesn't work unless it is captured completely with torch.compile",
1698        ):
1699            mod_for_compile(torch.tensor(True), torch.tensor(5))
1700
1701    def test_cond_with_constant_pred(self):
1702        def test(pred, x):
1703            def true_fn(x):
1704                return x
1705
1706            def false_fn(x):
1707                return -x
1708
1709            return control_flow.cond(pred, true_fn, false_fn, [x])
1710
1711        opt_test = torch.compile(test, backend="eager")
1712        inp = torch.ones(3, 3)
1713        self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
1714        self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
1715
1716    def test_map_graph_break(self):
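        # Mutating a module buffer inside the map body triggers a graph break,
        # so nothing is captured, but the compiled result still matches eager.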
1717        backend = EagerAndRecordGraphs()
1718        cnt = CompileCounterWithBackend(backend)
1719
1720        class Module(torch.nn.Module):
1721            def __init__(self) -> None:
1722                super().__init__()
1723                self.w = torch.nn.Buffer(torch.ones(6, 4))
1724
1725            def forward(self, xs):
1726                def body(x):
1727                    self.w += 1
1728                    return x
1729
1730                return control_flow.map(body, xs)
1731
1732        mod = Module()
1733
1734        mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
1735        mod_for_eager = Module()
1736
1737        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
1738        # There is a graph break right when we enter the body of map
1739        self.assertEqual(len(backend.graphs), 0)
1740        self.assertEqual(
1741            res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
1742        )
1743
1744    def test_map_side_effect(self):
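        # Mutating a list from an enclosing scope inside the map body forces a
        # fallback: no graphs are recorded, but the results still match eager.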
1745        backend = EagerAndRecordGraphs()
1746        cnt = CompileCounterWithBackend(backend)
1747
1748        z = [torch.ones(6, 4)]
1749
1750        class Module(torch.nn.Module):
1751            def __init__(self) -> None:
1752                super().__init__()
1753                self.w = torch.nn.Buffer(torch.ones(6, 4))
1754
1755            def forward(self, xs):
1756                def body(x):
1757                    z.append(x)
1758                    z.append(x)
1759                    z.pop()
1760                    return x + z[-1].sum()
1761
1762                return control_flow.map(body, xs)
1763
1764        mod = Module()
1765
1766        mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
1767        mod_for_eager = Module()
1768
1769        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
1770        res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
1771
1772        eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
1773        eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
1774
1775        self.assertEqual(len(backend.graphs), 0)
1776        self.assertEqual(res, eager)
1777
1778    def test_wrap_subgraph_name_is_valid(self):
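        # Nested wrap calls should produce uniquely named wrap_body_* submodules,
        # each nested inside its parent body in the resulting GraphModule.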
1779        backend = EagerAndRecordGraphs()
1780        cnt = CompileCounterWithBackend(backend)
1781
1782        x = torch.randn(3, 3)
1783        y = torch.randn(3, 3)
1784
1785        def inner(x, y):
1786            z = x + y
1787            return wrap(lambda x: wrap(lambda x: x + z, x), x)
1788
1789        @torch.compile(backend=cnt, fullgraph=True)
1790        def f(x, y):
1791            return wrap(inner, x, y)
1792
1793        result = f(x, y)
1794
1795        self.assertEqual(result, x + y + x)
1796        wrap_gm = backend.graphs[0]
1797        names = set()
1798        names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
1799        self.assertEqual(
1800            names,
1801            {
1802                "",
1803                "wrap_body_2",
1804                "wrap_body_2.wrap_body_1",
1805                "wrap_body_2.wrap_body_1.wrap_body_0",
1806            },
1807        )
1808
1809    def test_wrap_allow_local_assign_in_body_fn(self):
1810        def f(arg1, arg2):
1811            def inner_f(arg1, arg2):
1812                a = arg1
1813                b = arg2
1814                ret = []
1815                for x in a:
1816                    ret.append(x + 1)
1817                for x in b:
1818                    ret.append(x + 1)
1819                return ret
1820
1821            return wrap(inner_f, arg1, arg2)
1822
1823        x = torch.ones(3)
1824
1825        def my_args_generator():
1826            yield [x], [x.sin()]
1827            yield (x,), (x.sin(),)
1828
1829        actual_graph = self._test_wrap_simple(
1830            f,
1831            my_args_generator(),
1832            3,
1833            3,
1834            return_graph=True,
1835        )
1836
1837        # Dynamic shapes produce a slightly different graph.
1838        if check_dynamic_shape_capture():
1839            return
1840
1841        self.assertExpectedInline(
1842            actual_graph,
1843            """\
1844class GraphModule(torch.nn.Module):
1845    def forward(self, L_arg1_0_: "f32[3]", L_arg2_0_: "f32[3]"):
1846        l_arg1_0_ = L_arg1_0_
1847        l_arg2_0_ = L_arg2_0_
1848
1849        wrap_body_0 = self.wrap_body_0
1850        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_arg1_0_, l_arg2_0_);  wrap_body_0 = l_arg1_0_ = l_arg2_0_ = None
1851        getitem: "f32[3]" = wrap[0]
1852        getitem_1: "f32[3]" = wrap[1];  wrap = None
1853        return (getitem, getitem_1)
1854
1855    class wrap_body_0(torch.nn.Module):
1856        def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
1857            child: "f32[3]" = l_arg1_0_ + 1;  l_arg1_0_ = None
1858
1859            child_1: "f32[3]" = l_arg2_0_ + 1;  l_arg2_0_ = None
1860            return (child, child_1)
1861""",
1862        )
1863
1864    def test_capture_global_num(self):
1865        def f(x):
1866            return wrap(lambda x: x + global_num, x)
1867
1868        x = torch.zeros([])
1869        # Numbers don't get lifted, so args is still 2.
1870        self._test_wrap_simple(f, default_args_generator((x,)), 2)
1871
1872    def test_capture_global_num_adds_guard(self):
1873        @torch.compile(backend="eager", fullgraph=True)
1874        def f(x):
1875            return wrap(lambda x: x + global_num, x)
1876
1877        global global_num
1878        x = torch.zeros([])
1879        result = f(x)
1880        self.assertEqual(result, x + global_num)
1881
1882        global_num = torch.randn([]).item()
1883        result = f(x)
1884        self.assertEqual(result, x + global_num)
1885
1886    def test_capture_input_num(self):
1887        def f(x, y):
1888            return wrap(lambda x: x + y, x)
1889
1890        x = torch.zeros([])
1891        y = 3.14
1892        # Numbers don't get lifted, so args is still 2.
1893        self._test_wrap_simple(f, default_args_generator((x, y)), 2)
1894
1895    def test_side_effect_in_body(self):
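        # Writing to a nonlocal variable inside the wrapped body is a side effect,
        # so Dynamo falls back and records a SideEffects graph break.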
1896        counters.clear()
1897        backend = EagerAndRecordGraphs()
1898
1899        x = torch.randn([])
1900        y = torch.randn([])
1901
1902        def inner(x):
1903            nonlocal y
1904            y = x
1905            return x.clone()
1906
1907        @torch.compile(backend=backend)
1908        def f(x):
1909            return wrap(inner, x)
1910
1911        f(x)
1912        self.assertEqual(y, x)
1913        assert_dict_matches_regex(
1914            self,
1915            dict(counters["graph_break"]),
1916            {
1917                r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 1
1918            },
1919        )
1920
1921    def test_fallback_on_graph_break_simple(self):
1922        # In the future, there should be a per-HigherOrderOperator switch
1923        # on whether to fall back or raise a loud error.
1924        # For now we just fall back by default.
1925        cnt = CompileCounter()
1926        x = torch.randn([])
1927
1928        def inner(x):
1929            y = x.sin()
1930            torch._dynamo.graph_break()
1931            z = y.sin()
1932            return z
1933
1934        @torch.compile(backend=cnt)
1935        def f(x):
1936            return wrap(inner, x)
1937
1938        result = f(x)
1939        self.assertEqual(result, inner(x))
1940        self.assertEqual(cnt.frame_count, 0)
1941
1942    def test_fallback_on_graph_break_complicated(self):
1943        cnt = CompileCounter()
1944        x = torch.randn([])
1945
1946        def inner(x):
1947            y = x.sin()
1948            y = y * global_var
1949            torch._dynamo.graph_break()
1950            z = y.sin()
1951            return z
1952
1953        @torch.compile(backend=cnt)
1954        def f(x):
1955            x = x.clone()
1956            result = wrap(inner, x)
1957            return result.clone()
1958
1959        result = f(x)
1960        self.assertEqual(result, inner(x))
1961        self.assertEqual(cnt.frame_count, 2)
1962
1963    def test_modules(self):
1964        counters.clear()
1965        backend = EagerAndRecordGraphs()
1966        cnt = CompileCounterWithBackend(backend)
1967        mod = torch.nn.Linear(3, 3)
1968        x = torch.randn(3, 3)
1969
1970        @torch.compile(backend=cnt, fullgraph=True)
1971        def f(x):
1972            return wrap(lambda x: mod(x), x)
1973
1974        result = f(x)
1975
1976        self.assertEqual(result, mod(x))
1977        self.assertEqual(cnt.frame_count, 1)
1978
1979        self.assertEqual(len(backend.graphs), 1)
1980        wrap_node = find_first_node(backend.graphs[0], wrap)
1981        # 3 args - 1 for the input, and the other 2 for the weight and bias
1982        self.assertTrue(len(wrap_node.args), 3)
1983
1984        # Check that the linear bias and weight are accessed via getattr in the outer graph
1985        if not torch._dynamo.config.inline_inbuilt_nn_modules:
1986            self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2)
1987
1988        # Check that the inner function has one op and it's a linear op
1989        body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
1990        self.assertEqual(op_count(body_function), 1)
1991        linear_node = find_first_node(body_function, torch._C._nn.linear)
1992        self.assertTrue(linear_node is not None)
1993
1994        # Check that the innermost graph does not have any params
1995        self.assertTrue(len(dict(body_function.named_parameters())) == 0)
1996        self.assertTrue(len(dict(body_function.named_children())) == 0)
1997
1998    def test_flat_list_output(self):
1999        def f(x):
2000            return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)
2001
2002        x = torch.randn(3)
2003        self._test_wrap_simple(f, default_args_generator((x,)), 2, expected_opcount=3)
2004
2005    def test_fallback_on_python_primitives_output(self):
2006        counters.clear()
2007        cnt = CompileCounter()
2008
2009        @torch.compile(backend=cnt)
2010        def f(x):
2011            return wrap(lambda x: [1, torch.sin(x), 2.0], x)
2012
2013        x = torch.randn(3)
2014        result = f(x)
2015        self.assertEqual(result, [1, torch.sin(x), 2.0])
2016        self.assertEqual(cnt.frame_count, 0)
2017        assert_dict_matches_regex(
2018            self,
2019            dict(counters["graph_break"]),
2020            {".*HigherOrderOperator body's output must consist of tensors only": 1},
2021        )
2022
2023    def test_nested_tuple_output(self):
2024        def f(x):
2025            ((a, b),) = wrap(lambda x: ((x.sin(), x.cos()),), x)
2026            return a + b
2027
2028        x = torch.randn(2, 3)
2029
2030        counters.clear()
2031        graph = self._test_wrap_simple(
2032            f, default_args_generator((x,)), 2, 4, return_graph=True
2033        )
2034        self.assertEqual(len(counters["graph_break"]), 0)
2035
2036        if check_dynamic_shape_capture():
2037            return
2038
2039        self.assertExpectedInline(
2040            graph,
2041            """\
2042class GraphModule(torch.nn.Module):
2043    def forward(self, L_x_: "f32[2, 3]"):
2044        l_x_ = L_x_
2045
2046        wrap_body_0 = self.wrap_body_0
2047        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
2048        a: "f32[2, 3]" = wrap[0]
2049        b: "f32[2, 3]" = wrap[1];  wrap = None
2050
2051        add: "f32[2, 3]" = a + b;  a = b = None
2052        return (add,)
2053
2054    class wrap_body_0(torch.nn.Module):
2055        def forward(self, l_x_: "f32[2, 3]"):
2056            child: "f32[2, 3]" = l_x_.sin()
2057            child_1: "f32[2, 3]" = l_x_.cos();  l_x_ = None
2058            return (child, child_1)
2059""",
2060        )
2061
2062    def test_output_with_dict(self):
2063        def f(x):
2064            return wrap(lambda x: [{"a": -x}], x)
2065
2066        x = torch.randn(3)
2067
2068        counters.clear()
2069        graph = self._test_wrap_simple(
2070            f, default_args_generator((x,)), 2, 2, return_graph=True
2071        )
2072        self.assertEqual(len(counters["graph_break"]), 0)
2073
2074        if check_dynamic_shape_capture():
2075            return
2076
2077        self.assertExpectedInline(
2078            graph,
2079            """\
2080class GraphModule(torch.nn.Module):
2081    def forward(self, L_x_: "f32[3]"):
2082        l_x_ = L_x_
2083
2084        wrap_body_0 = self.wrap_body_0
2085        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
2086        getitem: "f32[3]" = wrap[0];  wrap = None
2087        return (getitem,)
2088
2089    class wrap_body_0(torch.nn.Module):
2090        def forward(self, l_x_: "f32[3]"):
2091            child: "f32[3]" = -l_x_;  l_x_ = None
2092            return (child,)
2093""",
2094        )
2095
2096    def test_access_module_attr(self):
2097        counters.clear()
2098        backend = EagerAndRecordGraphs()
2099        cnt = CompileCounterWithBackend(backend)
2100        mod = torch.nn.Linear(3, 3)
2101        x = torch.randn(3, 3)
2102
2103        @torch.compile(backend=cnt, fullgraph=True)
2104        def f(x):
2105            y = mod(x)
2106            return wrap(lambda y: y - mod.bias, y)
2107
2108        result = f(x)
2109        self.assertEqual(result, mod(x) - mod.bias)
2110        self.assertEqual(cnt.frame_count, 1)
2111
2112        self.assertEqual(len(backend.graphs), 1)
2113        wrap_node = find_first_node(backend.graphs[0], wrap)
2114        self.assertTrue(len(wrap_node.args), 3)
2115
2116        # Check that the linear bias and weight are accessed via getattr in the outer graph
2117        if not torch._dynamo.config.inline_inbuilt_nn_modules:
2118            self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2)
2119
2120        # Check that the inner function has exactly one op (the bias subtraction)
2121        body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
2122        self.assertEqual(op_count(body_function), 1)
2123
2124        # Check that the innermost graph does not have any params
2125        self.assertTrue(len(dict(body_function.named_parameters())) == 0)
2126        self.assertTrue(len(dict(body_function.named_children())) == 0)
2127
2128    def test_make_closure(self):
2129        def f(x, y):
2130            def g(x):
2131                return x + y
2132
2133            return g(x)
2134
2135        def h(x, y):
2136            return wrap(f, x, y)
2137
2138        x = torch.randn(3, 3)
2139        y = torch.randn(3, 3)
2140        self._test_wrap_simple(h, default_args_generator((x, y)), 3)
2141
2142    def test_internal_nonlocal(self):
2143        def f(x, y):
2144            w = 1
2145
2146            def g(x):
2147                nonlocal w
2148                w = x
2149                return x
2150
2151            def h(x):
2152                nonlocal w
2153                w = w + 1
2154                return x
2155
2156            g(x)
2157            h(x)
2158            return w + y
2159
2160        def h(x, y):
2161            return wrap(f, x, y)
2162
2163        x = torch.randn(3, 3)
2164        y = torch.randn(3, 3)
2165        self._test_wrap_simple(h, default_args_generator((x, y)), 3)
2166
2167    def test_capture_numpy_number(self):
2168        import numpy as np
2169
2170        y = np.float32(1.0)
2171
2172        def f(x):
2173            return wrap(lambda x: x + y, x)
2174
2175        x = torch.randn(3)
2176        # np.number instances are lifted to graph inputs
2177        self._test_wrap_simple(f, default_args_generator((x,)), 3)
2178
2179    def test_freevars_as_inputs_to_wrap(self):
2180        y = torch.randn(3)
2181
2182        def f(x):
2183            return wrap(lambda x, y: x + y, x, y)
2184
2185        x = torch.randn(3)
2186        self._test_wrap_simple(f, default_args_generator((x,)), 3)
2187
2188    def test_lift_tensor_constant(self):
2189        def f(x):
2190            y = torch.tensor(1.0)
2191            return wrap(lambda x: x + y, x)
2192
2193        x = torch.randn(3)
2194        self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=3)
2195
2196    def test_nested_wrap(self):
2197        class MockModule(torch.nn.Module):
2198            def __init__(self) -> None:
2199                super().__init__()
2200                self.linear = torch.nn.Linear(10, 10)
2201
2202            def forward(self, x):
2203                return self.linear(x)
2204
2205        mod = MockModule()
2206
2207        # Two levels of wrap ops
2208        def gn(x):
2209            return torch.cos(x) + wrap(mod, x)
2210
2211        def fn(x):
2212            return wrap(gn, x)
2213
2214        self._test_wrap_simple(fn, default_args_generator((torch.randn(10, 10),)), 4)
2215
2216    def test_fn_with_kwargs_in_torch_ops(self):
2217        def fn(x):
2218            return wrap(lambda z: torch.cos(input=z), x)
2219
2220        x = torch.randn(3)
2221        self._test_wrap_simple(fn, default_args_generator((x,)), 2)
2222
2223    def test_hooks(self):
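        # Forward hooks registered on a module that is called inside wrap should
        # still fire on every invocation, not just during the initial trace.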
2224        class ToyModel(torch.nn.Module):
2225            def __init__(self) -> None:
2226                super().__init__()
2227                self.net = torch.nn.Linear(10, 10)
2228
2229            def forward(self, x):
2230                return self.net(x)
2231
2232        model = ToyModel()
2233        forward_handles = {}
2234        activations = {}
2235
2236        def save_activations(mod, inp, out):
2237            activations[name] = inp
2238
2239        for name, module in model.named_children():
2240            forward_handles[name] = module.register_forward_hook(save_activations)
2241
2242        @torch.compile(backend="eager")
2243        def fn(x):
2244            return wrap(lambda x: model(x), x)
2245
2246        for i in range(2):
2247            # The second iteration is key; the hooks would already have fired
2248            # during the aot trace on the first iteration
2249            activations.clear()
2250            x = torch.randn((10, 10))
2251            pred = fn(x)
2252            loss = pred.sum()
2253            loss.backward()
2254
2255        self.assertTrue(activations.keys() == forward_handles.keys())
2256
2257    def _get_source_fn_stack(self, gm, node_names):
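        # Collect the "source_fn_stack" metadata (names only) for the requested
        # nodes across all submodules of the given GraphModule.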
2258        ret = {}
2259        for mod in gm.modules():
2260            for node in mod.graph.nodes:
2261                if node.name in node_names:
2262                    actual_stack = [
2263                        name for name, _ in node.meta.get("source_fn_stack", [])
2264                    ]
2265                    ret[node.name] = actual_stack
2266        return ret
2267
2268    def test_wrap_source_fn_stack(self):
2269        class MockModule(torch.nn.Module):
2270            def __init__(self) -> None:
2271                super().__init__()
2272                self.linear = torch.nn.Linear(4, 4)
2273
2274            def forward(self, x):
2275                return self.linear(x)
2276
2277        mod = MockModule()
2278
2279        def gn(x):
2280            return torch.cos(x) + wrap(mod, x)
2281
2282        def fn(x):
2283            return wrap(gn, x)
2284
2285        backend = EagerAndRecordGraphs()
2286        inp = torch.randn((4, 4))
2287        torch.compile(fn, backend=backend, fullgraph=True)(inp)
2288
2289        gm = backend.graphs[0]
2290        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "linear"})
2291        self.assertExpectedInline(
2292            pprint.pformat(actual_stack),
2293            """\
2294{'add': ['wrap', 'add'],
2295 'cos': ['wrap', 'cos'],
2296 'linear': ['wrap', 'wrap', 'linear']}""",
2297        )
2298
2299    def test_cond_source_fn_stack(self):
2300        backend = EagerAndRecordGraphs()
2301
2302        @torch.compile(backend=backend, fullgraph=True)
2303        def cond_f(pred, pred2, x, y):
2304            def true_fn(pred2, x, y):
2305                return x + y
2306
2307            def false_fn(pred2, x, y):
2308                def true_fn2(x, y):
2309                    return x.sin() - y.cos()
2310
2311                def false_fn2(x, y):
2312                    return x.cos() - y.sin()
2313
2314                return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
2315
2316            return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
2317
2318        pred = torch.tensor(True)
2319        pred2 = torch.tensor(False)
2320        xs = torch.randn(2, 3, 3)
2321        y = torch.randn(3, 3)
2322        cond_f(pred, pred2, xs, y)
2323
2324        gm = backend.graphs[0]
2325        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin", "sub"})
2326        self.assertExpectedInline(
2327            pprint.pformat(actual_stack),
2328            """\
2329{'add': ['cond', 'add'],
2330 'cos': ['cond', 'cond', 'cos'],
2331 'sin': ['cond', 'cond', 'sin'],
2332 'sub': ['cond', 'cond', 'sub']}""",
2333        )
2334
2335    def test_map_source_fn_stack(self):
2336        backend = EagerAndRecordGraphs()
2337
2338        xs = torch.randn(2, 3, 3)
2339        y = torch.randn(3)
2340
2341        @torch.compile(backend=backend, fullgraph=True)
2342        def map_f(xs, y):
2343            def inner(x, y):
2344                def inner2(x, y):
2345                    return x + y
2346
2347                return control_flow.map(inner2, x, y) * y.cos()
2348
2349            return control_flow.map(inner, xs, y).sin()
2350
2351        result = map_f(xs, y)
2352
2353        gm = backend.graphs[0]
2354        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
2355        self.assertExpectedInline(
2356            pprint.pformat(actual_stack),
2357            """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
2358        )
2359
2360    def test_grad_source_fn_stack(self):
2361        backend = EagerAndRecordGraphs()
2362
2363        def fn(x):
2364            return x.sin().sum()
2365
2366        @torch.compile(backend=backend, fullgraph=False)
2367        def wrapper_fn(x):
2368            return torch.func.grad(torch.func.grad(fn))(x)
2369
2370        x = torch.randn(())
2371
2372        wrapper_fn(x)
2373        gm = backend.graphs[0]
2374        actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"})
2375        self.assertExpectedInline(
2376            pprint.pformat(actual_stack),
2377            """{'sin': ['sin']}""",
2378        )
2379
2380    def test_vmap_multiply_scalar(self):
2381        @torch.compile(backend="inductor", fullgraph=True)
2382        def g(x):
2383            return torch.vmap(torch.mul, in_dims=(0, None))(x, 3.14)
2384
2385        x = torch.randn(3)
2386        y = g(x)
2387        self.assertEqual(y, x * 3.14)
2388
2389        @torch.compile(backend="inductor", fullgraph=True)
2390        def f(x):
2391            return torch.vmap(torch.mul, in_dims=(0, None))(x, 314)
2392
2393        x = torch.randn(3)
2394        y = f(x)
2395        self.assertEqual(y, x * 314)
2396
2397    def test_vmap_source_fn_stack(self):
2398        backend = EagerAndRecordGraphs()
2399
2400        def inner_fn(x):
2401            return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
2402
2403        @torch.compile(backend=backend, fullgraph=True)
2404        def fn(x):
2405            return torch.func.vmap(lambda x: inner_fn(x.cos()))(x)
2406
2407        x = torch.randn(3, 3, 3, 3)
2408        fn(x)
2409        gm = backend.graphs[0]
2410        actual_stack = self._get_source_fn_stack(
2411            gm, {"sum_1", "sum_2", "batched_output"}
2412        )
2413        self.assertExpectedInline(
2414            pprint.pformat(actual_stack),
2415            """{'sum_1': ['sum_1'], 'sum_2': ['sum_2']}""",
2416        )
2417
2418    def test_cond_pytree_operands(self):
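        # torch.cond should accept arbitrarily nested pytree operands and flatten
        # them into individual tensor inputs of the captured graph.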
2419        def _construct_pytree():
2420            a = torch.randn(3, 3)
2421            b = torch.randn(3, 3)
2422            c = torch.randn(3, 3)
2423            d = torch.randn(3, 3)
2424            e = torch.randn(3, 3)
2425            f = torch.randn(3, 3)
2426            g = torch.randn(3, 3)
2427            return (a, [[[b]]], c, (d, (e,), f), {"g": g})
2428
2429        pred = torch.tensor(True)
2430        inp = _construct_pytree()
2431
2432        def _reduce_sum(flattened):
2433            init = 0
2434            for val in flattened:
2435                init += val
2436            return init
2437
2438        def _reduce_max(flattened):
2439            init = flattened[0]
2440            for val in flattened:
2441                init = max(val, init)
2442            return init
2443
2444        def true_fn(pytree_in):
2445            flattened, spec = pytree.tree_flatten(pytree_in)
2446            return _reduce_sum(flattened)
2447
2448        def false_fn(pytree_in):
2449            flattened, spec = pytree.tree_flatten(pytree_in)
2450            return _reduce_max(flattened)
2451
2452        def fn(pred, pytree_in):
2453            return torch.cond(pred, true_fn, false_fn, [pytree_in])
2454
2455        backend = EagerAndRecordGraphs()
2456        cnt = CompileCounterWithBackend(backend)
2457        compiled_res = torch.compile(fn, backend=backend)(pred, inp)
2458        eager_res = fn(pred, inp)
2459        self.assertEqual(compiled_res, eager_res)
2460        graph = backend.graphs[0]
2461
2462        # Dynamic shapes produce a slightly different graph.
2463        if check_dynamic_shape_capture():
2464            return
2465
2466        self.assertExpectedInline(
2467            graph.code.strip(),
2468            """\
2469def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
2470    l_pred_ = L_pred_
2471    l_pytree_in_0_ = L_pytree_in_0_
2472    l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
2473    l_pytree_in_2_ = L_pytree_in_2_
2474    l_pytree_in_3_0_ = L_pytree_in_3_0_
2475    l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
2476    l_pytree_in_3_2_ = L_pytree_in_3_2_
2477    l_pytree_in_4_g_ = L_pytree_in_4_g_
2478    cond_true_0 = self.cond_true_0
2479    cond_false_0 = self.cond_false_0
2480    cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]);  l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
2481    getitem = cond[0];  cond = None
2482    return (getitem,)""",  # noqa: B950
2483        )
2484
2485    def test_cond_pytree_operands_with_non_tensor_leaves(self):
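        # Non-tensor leaves in the operand pytree are rejected both in eager mode
        # and under torch.compile, just with different error types.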
2486        def fn(pred, pytree_in):
2487            return torch.cond(
2488                pred, lambda x: x[0] + 1, lambda x: x[0] * 2, (pytree_in,)
2489            )
2490
2491        pred = torch.tensor(True)
2492        for pytree_in in [(1,), ("string",), (1.0,)]:
2493            with self.assertRaisesRegex(
2494                RuntimeError,
2495                r"Expect operands to be a tuple of possibly nested dict/list/tuple",
2496            ):
2497                fn(pred, pytree_in)
2498
2499        for pytree_in in [(1,), ("string",), (1.0,)]:
2500            with self.assertRaisesRegex(
2501                torch._dynamo.exc.UncapturedHigherOrderOpError,
2502                r"Cond doesn't work unless it is captured completely with torch.compile",
2503            ):
2504                torch.compile(fn, backend="eager")(pred, pytree_in)
2505
2506    def test_hints_wrapper(self):
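        # hints_wrapper should lower each body function to its own submodule and
        # carry the hints dict through to the higher-order op call in the graph.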
2507        def ref_fn(x, y):
2508            x = x + y
2509            x = torch.relu(x)
2510            x = x + y
2511            return torch.abs(x)
2512
2513        def fn_with_hints(x, y):
2514            x = x + y
2515
2516            def inner_body_fn(x, y):
2517                x = torch.relu(x)
2518                x = x + y
2519                return x
2520
2521            def outer_body_fn(x, y):
2522                x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True})
2523                x = torch.abs(x)
2524                return x
2525
2526            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True})
2527            return res
2528
2529        backend = EagerAndRecordGraphs()
2530        cnt = CompileCounterWithBackend(backend)
2531
2532        x = torch.randn(2, 4)
2533        y = torch.ones(4)
2534
2535        eager_res = fn_with_hints(x, y)
2536        compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
2537        ref_res = ref_fn(x, y)
2538        self.assertEqual(eager_res, ref_res)
2539        self.assertEqual(compiled_res, ref_res)
2540        self.assertEqual(len(cnt.graphs), 1)
2541
2542        # Dynamic shapes produce a slightly different graph.
2543        if check_dynamic_shape_capture():
2544            return
2545
2546        graph = backend.graphs[0]
2547        self.assertExpectedInline(
2548            normalize_gm(graph.print_readable(print_output=False)),
2549            """\
2550class GraphModule(torch.nn.Module):
2551    def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"):
2552        l_x_ = L_x_
2553        l_y_ = L_y_
2554
2555        x: "f32[2, 4]" = l_x_ + l_y_;  l_x_ = None
2556
2557        hints_wrapper_body_1 = self.hints_wrapper_body_1
2558        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True});  hints_wrapper_body_1 = x = l_y_ = None
2559        res: "f32[2, 4]" = hints_wrapper[0];  hints_wrapper = None
2560        return (res,)
2561
2562    class hints_wrapper_body_1(torch.nn.Module):
2563        def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
2564            hints_wrapper_body_0 = self.hints_wrapper_body_0
2565            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True});  hints_wrapper_body_0 = x = l_y_ = None
2566            x_1: "f32[2, 4]" = hints_wrapper[0];  hints_wrapper = None
2567
2568            x_2: "f32[2, 4]" = torch.abs(x_1);  x_1 = None
2569            return (x_2,)
2570
2571        class hints_wrapper_body_0(torch.nn.Module):
2572            def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
2573                x_1: "f32[2, 4]" = torch.relu(x);  x = None
2574
2575                x_2: "f32[2, 4]" = x_1 + l_y_;  x_1 = l_y_ = None
2576                return (x_2,)
2577""",
2578        )
2579
2580    def test_hints_wrapper_no_hints(self):
2581        def fn_with_hints(x, y):
2582            def outer_body_fn(x, y):
2583                x = torch.add(x, y)
2584                return x
2585
2586            res = hints_wrapper(outer_body_fn, (x, y), {})
2587            return res
2588
2589        backend = EagerAndRecordGraphs()
2590        cnt = CompileCounterWithBackend(backend)
2591
2592        x = torch.randn(2, 4)
2593        y = torch.ones(4)
2594
2595        msg = "hints_wrapper - key hints not provided"
2596        with self.assertRaisesRegex(RuntimeError, msg):
2597            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
2598
2599    def test_hints_wrapper_incorrect_type(self):
2600        def fn_with_hints(x, y):
2601            def outer_body_fn(x, y):
2602                x = torch.add(x, y)
2603                return x
2604
2605            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)})
2606            return res
2607
2608        backend = EagerAndRecordGraphs()
2609        cnt = CompileCounterWithBackend(backend)
2610
2611        x = torch.randn(2, 4)
2612        y = torch.ones(4)
2613
2614        msg = r"hints must be a dict containing int, float, bool or str value,"
2615        with self.assertRaisesRegex(RuntimeError, msg):
2616            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
2617
2618    def test_hints_wrapper_pytree_inputs(self):
2619        def fn_with_hints(x, y):
2620            def outer_body_fn(x):
2621                res = torch.add(x[0], x[1]["test"])
2622                return res
2623
2624            res = hints_wrapper(
2625                outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True}
2626            )
2627            return res
2628
2629        backend = EagerAndRecordGraphs()
2630        cnt = CompileCounterWithBackend(backend)
2631
2632        x = torch.randn(2, 4)
2633        y = torch.ones(4)
2634
2635        msg = r"args must be a tuple of tensors, ints, floats, or bools,"
2636        with self.assertRaisesRegex(RuntimeError, msg):
2637            fn_with_hints(x, y)
2638
2639
2640class HigherOrderOpVmapGuardTests(LoggingTestCase):
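    # These tests use the recompiles/guards logs to check that changes in the
    # functorch transform state (vmap/grad/jvp/forward-AD) trigger the expected
    # guards and recompilations.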
2641    @make_logging_test(recompiles=True)
2642    def test_vmap_grad_guard_ok(self, records):
2643        vmap = torch.vmap
2644        grad = torch.func.grad
2645
2646        def g(x):
2647            return vmap(grad(torch.sin))(x)
2648
2649        @torch.compile(backend="eager")
2650        def fn(x):
2651            return vmap(g)(x)
2652
2653        x = torch.randn(4, 5)
2654        y = fn(x)
2655        # sanity check
2656        self.assertEqual(len(records), 0)
2657        self.assertEqual(x.cos(), y)
2658
2659        # Calling the same function again won't have any effect on guards
2660        fn(x)
2661        self.assertEqual(len(records), 0)
2662
2663    @xfailIfTorchDynamo
2664    @make_logging_test(recompiles=True)
2665    def test_grad_guard_fail(self, records):
2666        grad = torch.func.grad
2667
2668        @torch.compile(backend="eager")
2669        def fn(x):
2670            return grad(torch.sin)(x.sum())
2671
2672        x = torch.randn([])
2673        fn(x)
2674        self.assertEqual(len(records), 0)
2675
2676        # calling again should not invalidate the graph
2677        fn(x)
2678        self.assertEqual(len(records), 0)
2679
2680        # calling grad should retrigger compilation
2681        x = torch.randn(3)
2682        grad(fn)(x)
2683        self.assertGreater(len(records), 0)
2684        record = self.getRecord(records, "pyfunctorch")
2685        self.assertIn(
2686            """torch._functorch.pyfunctorch.compare_functorch_state([])""",
2687            munge_exc(record.getMessage()),
2688        )
2689
2690    @make_logging_test(recompiles=True)
2691    def test_dual_level_guard(self, records):
2692        fwAD = torch.autograd.forward_ad
2693
2694        @torch.compile(backend="eager", fullgraph=True)
2695        def fn(foo, tangent):
2696            with fwAD.dual_level():
2697                dual = fwAD.make_dual(foo, tangent[1:])
2698                return dual
2699
2700        foo = torch.rand(2)
2701        tangent = torch.rand(3)
2702        fn(foo, tangent)
2703        self.assertEqual(len(records), 0)
2704
2705        # calling again should not invalidate the graph
2706        fn(foo, tangent)
2707        self.assertEqual(len(records), 0)
2708
2709        # assertRaises is only here because nested forward-mode AD is not supported
2710        with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError):
2711            with fwAD.dual_level():
2712                fn(foo, tangent)
2713        self.assertGreater(len(records), 0)
2714        record = self.getRecord(records, "forward_ad")
2715        self.assertIn(
2716            """torch.autograd.forward_ad._current_level == -1""",
2717            munge_exc(record.getMessage()),
2718        )
2719
2720    @xfailIfTorchDynamo
2721    @make_logging_test(recompiles=True)
2722    def test_jvp_guard_fail(self, records):
2723        jvp = torch.func.jvp
2724        vmap = torch.func.vmap
2725
2726        @torch.compile(backend="eager")
2727        def fn(x):
2728            return jvp(torch.sin, (x,), (x,))
2729
2730        x = torch.randn(3, 4)
2731        fn(x)
2732        self.assertEqual(len(records), 0)
2733
2734        # calling again should not invalidate the graph
2735        fn(x)
2736        self.assertEqual(len(records), 0)
2737
2738        # calling jvp should retrigger compilation
2739        x = torch.randn(3, 4, 5)
2740        jvp(vmap(fn), (x,), (x,))
2741
2742        self.assertGreater(len(records), 0)
2743        if self.hasRecord(records, "pyfunctorch"):
2744            record = self.getRecord(records, "pyfunctorch")
2745            self.assertIn(
2746                """torch._functorch.pyfunctorch.compare_functorch_state([])""",
2747                munge_exc(record.getMessage()),
2748            )
2749        elif self.hasRecord(records, "forward_ad"):
2750            record = self.getRecord(records, "forward_ad")
2751            self.assertIn(
2752                """torch.autograd.forward_ad._current_level == -1""",
2753                munge_exc(record.getMessage()),
2754            )
2755
2756    @make_logging_test(recompiles=True)
2757    def test_vmap_guard_ok(self, records):
2758        @torch.compile(backend="eager")
2759        def fn(x):
2760            return torch.vmap(lambda x: x.sin())(x)
2761
2762        x = torch.randn(3, 3, 4, 5)
2763        y = fn(x)
2764        # sanity check
2765        self.assertEqual(len(records), 0)
2766        self.assertEqual(x.sin(), y)
2767
2768        # Calling the same function again won't have any effect on guards
2769        z = fn(x)
2770        self.assertEqual(len(records), 0)
2771        self.assertEqual(x.sin(), z)
2772
2773        # calling with a different object will also not affect guards
2774        w = fn(z)
2775        self.assertEqual(len(records), 0)
2776        self.assertEqual(z.sin(), w)
2777
2778    @xfailIfTorchDynamo
2779    @make_logging_test(recompiles=True)
2780    def test_vmap_guard_fail_different_state(self, records):
2781        @torch.compile(backend="eager")
2782        def fn(x):
2783            return torch.vmap(lambda x: x.sin())(x)
2784
2785        x = torch.zeros(3, 4)
2786        y = torch.vmap(fn, randomness="same")(x)
2787        self.assertEqual(x.sin(), y)
2788        self.assertEqual(len(records), 0)
2789
2790        # calling vmap(fn) with different randomness should retrigger compilation
2791        y = torch.vmap(fn, randomness="different")(x)
2792        self.assertEqual(x.sin(), y)
2793        self.assertGreater(len(records), 0)
2794        record = self.getRecord(records, "pyfunctorch")
2795        self.assertIn(
2796            """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
2797            record.getMessage(),
2798        )
2799
2800    @xfailIfTorchDynamo
2801    @make_logging_test(recompiles=True)
2802    def test_vmap_guard_fail(self, records):
2803        @torch.compile(backend="eager")
2804        def fn(x):
2805            return torch.vmap(lambda x: x.sin())(x)
2806
2807        x = torch.zeros(3, 3, 4, 5)
2808        y = torch.vmap(fn)(x)
2809        self.assertEqual(x.sin(), y)
2810        self.assertEqual(len(records), 0)
2811
2812        # calling vmap(vmap(fn))(x) should retrigger compilation as
2813        # _functorch.current_level() is not the same
2814        x = torch.zeros(3, 3, 3, 4, 5)
2815        y = torch.vmap(torch.vmap(fn))(x)
2816        self.assertEqual(x.sin(), y)
2817        self.assertGreater(len(records), 0)
2818        record = self.getRecord(records, "pyfunctorch")
2819        self.assertIn(
2820            """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
2821            record.getMessage(),
2822        )
2823
2824    @xfailIfTorchDynamo
2825    @make_logging_test(recompiles=True)
2826    def test_vmap_grad_vmap_guard_fail(self, records):
2827        vmap = torch.vmap
2828        grad = torch.func.grad
2829
2830        def g(x):
2831            y = vmap(torch.sin, randomness="same")(x)
2832            return y.sum(0)
2833
2834        @torch.compile(backend="eager")
2835        def fn(x):
2836            return grad(g)(x)
2837
2838        x = torch.randn(3, 3)
2839        y = vmap(fn, randomness="error")(x)
2840        self.assertEqual(x.cos(), y)
2841
2842        # previous FX graph should be invalidated
2843        x = torch.randn(3, 3, 4)
2844        y = vmap(vmap(fn, randomness="different"))(x)
2845        self.assertGreater(len(records), 0)
2846        record = self.getRecord(records, "pyfunctorch")
2847        self.assertIn(
2848            """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
2849            munge_exc(record.getMessage()),
2850        )
2851
2852    @xfailIfTorchDynamo
2853    @make_logging_test(recompiles=True)
2854    def test_vmap_recompile_different_states(self, records):
2855        @torch.compile(backend="eager")
2856        def fn(x):
2857            return torch.vmap(lambda x: x.sin())(x)
2858
2859        x = torch.zeros(3, 3, 4, 5)
2860        y = torch.vmap(fn, randomness="same")(x)
2861        self.assertEqual(len(records), 0)  # sanity check
2862
2863        y = torch.vmap(fn, randomness="different")(x)
2864        self.assertGreater(len(records), 0)
2865        record = self.getRecord(records, "pyfunctorch")
2866        self.assertIn(
2867            """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
2868            munge_exc(record.getMessage()),
2869        )
2870
2871    @config.patch(capture_func_transforms=True)
2872    @make_logging_test(guards=True)
2873    def test_emit_functorch_guard_if_active(self, records):
2874        @torch.compile(backend="eager")
2875        def fn(x):
2876            return torch.sin(x)
2877
2878        x = torch.randn(3, 4)
2879        _ = fn(x)
2880        self.assertFalse(self.hasRecord(records, "pyfunctorch"))  # sanity check
2881
2882        _ = torch.vmap(fn)(x)
2883        self.assertTrue(self.hasRecord(records, "pyfunctorch"))
2884        record = self.getRecord(records, "pyfunctorch")
2885        self.assertIn(
2886            """torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
2887            munge_exc(record.getMessage()),
2888        )
2889
2890    @make_logging_test(recompiles=True)
2891    def test_linearize_recompiles(self, records):
2892        @torch.compile(backend="eager")
2893        def fn(x):
2894            out, jvp_fn = torch.func.linearize(torch.sin, x)
2895            return out, jvp_fn(x)
2896
2897        x = torch.randn(2, 3)
2898        fn(x)
2899        self.assertEqual(len(records), 0)
2900
2901        z = torch.randn(2, 3)
2902        fn(z)
2903        self.assertEqual(len(records), 0)
2904
2905        y = torch.randn(3, 4)
2906        fn(y)
2907        self.assertGreater(len(records), 0)
2908
2909
2910class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
2911    def tearDown(self):
2912        # Ensure that in the case of a test failure, the next test won't fail
2913        # because of a previous call to _vmap_increment_nesting that wasn't undone,
2914        # e.g. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1
2915        # and the call to increment nesting is not undone
2916        if not TEST_WITH_TORCHDYNAMO:
2917            return
2918
2919        warn = False
2920        while ci := torch._C._functorch.peek_interpreter_stack():
2921            if ci.key() == torch._C._functorch.TransformType.Vmap:
2922                warn = True
2923                torch._C._functorch._vmap_decrement_nesting()
2924            else:
2925                break
2926
2927        if warn:
2928            msg = (
2929                "Interpreter stack is not empty. Test should have called "
2930                "'torch._C._functorch._vmap_decrement_nesting()'"
2931            )
2932            warnings.warn(msg)
2933
2934    def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0):
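        # Run fn both eagerly and under torch.compile with a recording backend,
        # assert that the results match, and return the captured graph at graph_idx.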
2935        backend = EagerAndRecordGraphs()
2936        actual = fn(*inputs)
2937        expected = torch.compile(fn, backend=backend, fullgraph=fullgraph)(*inputs)
2938
2939        self.assertEqual(actual, expected)
2940
2941        wrapped_gm = backend.graphs[graph_idx]
2942        return wrapped_gm
2943
2944    def test_hessian(self):
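        # torch.func.hessian(torch.sin) should be captured in a single graph; the
        # expected readable output below is only checked for static shapes.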
2945        counters.clear()
2946
2947        def wrapper_fn(x):
2948            return torch.func.hessian(torch.sin)(x)
2949
2950        x = torch.randn(4, 3)
2951        wrapped_gm = self._compile_check(wrapper_fn, (x,))
2952        # Dynamic shapes produce a slightly different graph.
2953        if check_dynamic_shape_capture():
2954            return
2955
2956        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
2957        self.assertExpectedInline(
2958            actual,
2959            """\
2960class GraphModule(torch.nn.Module):
2961    def forward(self, L_x_: "f32[4, 3]"):
2962        l_x_ = L_x_
2963
2964        tensor: "i64[1]" = torch.tensor((12,))
2965        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
2966        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
2967        neg: "i64[0]" = getitem.neg();  getitem = None
2968        unbind = neg.unbind();  neg = unbind = None
2969
2970        chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
2971
2972        diagonal: "f32[12]" = chunk.diagonal(0)
2973        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
2974
2975        child: "f32[12, 4, 3]" = chunk.view(12, 4, 3);  chunk = None
2976
2977        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
2978
2979        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
2980
2981        child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
2982
2983        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,));  _jvp_treespec_compare = None
2984
2985        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
2986        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
2987        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
2988
2989        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
2990
2991        child_2: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
2992
2993        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
2994
2995        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
2996        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
2997
2998        diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3);  child_2 = None
2999
3000        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3001
3002        _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
3003
3004        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3005
3006        o: "f32[4, 3]" = torch.sin(diff_primals)
3007
3008        results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
3009
3010        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3011        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3012
3013        tensor_1: "i64[1]" = torch.tensor((12,))
3014        cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0);  tensor_1 = None
3015        getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)];  cumsum_1 = None
3016        neg_1: "i64[0]" = getitem_1.neg();  getitem_1 = None
3017        unbind_1 = neg_1.unbind();  neg_1 = unbind_1 = None
3018
3019        chunk_1: "f32[12, 12]" = results.new_zeros(12, 12);  results = None
3020
3021        diagonal_1: "f32[12]" = chunk_1.diagonal(0)
3022        fill__1: "f32[12]" = diagonal_1.fill_(1);  diagonal_1 = fill__1 = None
3023
3024        basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3);  chunk_1 = None
3025
3026        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
3027
3028        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting_1 = None
3029
3030        _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3);  basis = None
3031
3032        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1);  _vjp_treespec_compare = None
3033
3034        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim_1 = None
3035        batched_outputs: "f32[4, 3]" = _autograd_grad[0];  _autograd_grad = None
3036
3037        chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0);  batched_outputs = None
3038
3039        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
3040
3041        split = chunked_result.split((12,), dim = 0);  chunked_result = None
3042        split_1: "f32[12, 4, 3]" = split[0];  split = None
3043
3044        output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3));  split_1 = None
3045
3046        _unpack_dual = torch._unpack_dual(output_input, level = 0);  output_input = None
3047        primal: "f32[4, 3, 4, 3]" = _unpack_dual[0]
3048        dual: "f32[4, 3, 4, 3]" = _unpack_dual[1];  _unpack_dual = None
3049
3050        primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
3051
3052        tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
3053
3054        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
3055        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
3056        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
3057
3058        results_1: "f32[12, 4, 3, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0);  tangents_out_unflatten = None
3059
3060        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting_1 = None
3061
3062        movedim: "f32[4, 3, 4, 3, 12]" = results_1.movedim(0, -1);  results_1 = None
3063        split_2 = movedim.split((12,), dim = -1);  movedim = None
3064        jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0];  split_2 = None
3065
3066        unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3));  jac_out_in = None
3067        return (unflatten,)
3068""",
3069        )
3070
3071    def test_hessian_argnums(self):
3072        counters.clear()
3073
3074        def fn(x, y):
3075            return x.sin()
3076
3077        def wrapper_fn(x, y):
3078            return torch.func.hessian(fn, argnums=(1,))(x, y)
3079
3080        x = torch.randn(4, 3)
3081        y = torch.randn(3, 4)
3082        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
3083        # Dynamic shapes produce a slightly different graph.
3084        if check_dynamic_shape_capture():
3085            return
3086
3087        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3088        self.assertExpectedInline(
3089            "\n".join(actual.split("\n")[:-2]),
3090            """\
3091class GraphModule(torch.nn.Module):
3092    def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
3093        l_x_ = L_x_
3094        l_y_ = L_y_
3095
3096        tensor: "i64[1]" = torch.tensor((12,))
3097        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
3098        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
3099        neg: "i64[0]" = getitem.neg();  getitem = None
3100        unbind = neg.unbind();  neg = unbind = None
3101
3102        chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
3103
3104        diagonal: "f32[12]" = chunk.diagonal(0)
3105        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
3106
3107        child: "f32[12, 3, 4]" = chunk.view(12, 3, 4);  chunk = None
3108
3109        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
3110
3111        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
3112
3113        child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
3114
3115        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,));  _jvp_treespec_compare = None
3116
3117        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
3118        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
3119        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
3120
3121        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
3122
3123        child_3: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
3124
3125        child_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = None
3126        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
3127
3128        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3129        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3130
3131        _wrap_for_grad_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3);  child_2 = None
3132        child_4: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(child_3, 3);  child_3 = None
3133
3134        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3135
3136        _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4);  _set_tensor_requires_grad = None
3137
3138        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3139
3140        o: "f32[4, 3]" = _wrap_for_grad_2.sin();  _wrap_for_grad_2 = None
3141
3142        results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
3143
3144        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3145        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3146
3147        tensor_1: "i64[1]" = torch.tensor((12,))
3148        cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0);  tensor_1 = None
3149        getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)];  cumsum_1 = None
3150        neg_1: "i64[0]" = getitem_1.neg();  getitem_1 = None
3151        unbind_1 = neg_1.unbind();  neg_1 = unbind_1 = None
3152
3153        chunk_1: "f32[12, 12]" = results.new_zeros(12, 12);  results = None
3154
3155        diagonal_1: "f32[12]" = chunk_1.diagonal(0)
3156        fill__1: "f32[12]" = diagonal_1.fill_(1);  diagonal_1 = fill__1 = None
3157
3158        basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3);  chunk_1 = None
3159
3160        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
3161
3162        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting_1 = None
3163
3164        _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3);  basis = None
3165
3166        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1);  _vjp_treespec_compare = None
3167
3168        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True);  o = child_4 = _add_batch_dim_1 = None
3169        child_5: "f32[3, 4]" = _autograd_grad[0];  _autograd_grad = None
3170
3171        child_6: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0);  child_5 = None
3172
3173        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
3174
3175        split = child_6.split((12,), dim = 0);  child_6 = None
3176        split_1: "f32[12, 3, 4]" = split[0];  split = None
3177
3178        child_7: "f32[4, 3, 3, 4]" = split_1.view((4, 3, 3, 4));  split_1 = None
3179
3180        _unpack_dual = torch._unpack_dual(child_7, level = 0);  child_7 = None
3181        primal: "f32[4, 3, 3, 4]" = _unpack_dual[0];  _unpack_dual = None
3182
3183        tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal)
3184
3185        child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = child_8 = None
3186
3187        child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2);  tangent = None
3188
3189        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
3190        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
3191        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
3192
3193        child_10: "f32[12, 4, 3, 3, 4]" = torch._C._functorch._remove_batch_dim(child_9, 1, 12, 0);  child_9 = None
3194
3195        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting_1 = None
3196
3197        movedim: "f32[4, 3, 3, 4, 12]" = child_10.movedim(0, -1);  child_10 = None
3198        split_2 = movedim.split((12,), dim = -1);  movedim = None
3199        jac_out_in: "f32[4, 3, 3, 4, 12]" = split_2[0];  split_2 = None
3200
3201        unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4));  jac_out_in = None""",
3202        )
3203
3204        self.assertExpectedInline(
3205            actual.split("\n")[-2],
3206            """        return (unflatten,)""",
3207        )
3208
3209    def test_hessian_disable_capture(self):
3210        counters.clear()
3211
3212        with config.patch(capture_func_transforms=False):
3213            # We have verified above that this function compiles with capture
3214            # enabled; here capture is disabled, so we expect graph breaks instead.
3215            def wrapper_fn(x):
3216                return torch.func.hessian(torch.sin)(x)
3217
3218            x = torch.randn(3, 3, 3)
3219            actual = wrapper_fn(x)
3220            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
3221                x
3222            )
3223            self.assertEqual(len(counters["graph_break"]), 2)
3224            self.assertEqual(
3225                {
3226                    "torch.func.vmap capture is disabled, it can be "
3227                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2,
3228                    "torch.func.hessian capture is disabled, it can be "
3229                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1,
3230                },
3231                dict(counters["graph_break"]),
3232            )
3233            self.assertEqual(actual, expected)
3234
3235    def test_jacrev(self):
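        # Checks that torch.func.jacrev(torch.sin) is captured into a single graph.
        # Roughly, in eager mode this computes:
        #   J = torch.func.jacrev(torch.sin)(x)  # J[i, j, k, l] = d sin(x)[i, j] / d x[k, l]
        # The expected output below shows the grad/vjp wrapping plus the vmap over
        # basis vectors that jacrev uses to build the Jacobian.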
3236        counters.clear()
3237
3238        def wrapper_fn(x):
3239            return torch.func.jacrev(torch.sin)(x)
3240
3241        x = torch.randn(4, 3)
3242        wrapped_gm = self._compile_check(wrapper_fn, (x,))
3243        # Dynamic shapes produce a slightly different graph.
3244        if check_dynamic_shape_capture():
3245            return
3246
3247        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3248        self.assertExpectedInline(
3249            actual,
3250            """\
3251class GraphModule(torch.nn.Module):
3252    def forward(self, L_x_: "f32[4, 3]"):
3253        l_x_ = L_x_
3254
3255        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3256        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3257
3258        diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3259
3260        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3261
3262        _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
3263
3264        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3265
3266        o: "f32[4, 3]" = torch.sin(diff_primals)
3267
3268        results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 1)
3269
3270        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3271        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3272
3273        tensor: "i64[1]" = torch.tensor((12,))
3274        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
3275        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
3276        neg: "i64[0]" = getitem.neg();  getitem = None
3277        unbind = neg.unbind();  neg = unbind = None
3278
3279        chunk: "f32[12, 12]" = results.new_zeros(12, 12);  results = None
3280
3281        diagonal: "f32[12]" = chunk.diagonal(0)
3282        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
3283
3284        basis: "f32[12, 4, 3]" = chunk.view(12, 4, 3);  chunk = None
3285
3286        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
3287
3288        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
3289
3290        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
3291
3292        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim);  _vjp_treespec_compare = None
3293
3294        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None
3295        batched_outputs: "f32[4, 3]" = _autograd_grad[0];  _autograd_grad = None
3296
3297        chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None
3298
3299        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
3300
3301        split = chunked_result.split((12,), dim = 0);  chunked_result = None
3302        split_1: "f32[12, 4, 3]" = split[0];  split = None
3303
3304        output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3));  split_1 = None
3305        return (output_input,)
3306""",
3307        )
3308
3309    def test_jacrev_two_tensors_argnums(self):
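        # With argnums=1 only `y` is differentiated: in the captured graph `x` is
        # wrapped but never marked as requiring grad, and the Jacobian has shape
        # (3, 4, 3, 4), matching y's output and input dimensions.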
3310        counters.clear()
3311
3312        def fn(x, y):
3313            return y.sin()
3314
3315        def wrapper_fn(x, y):
3316            return torch.func.jacrev(fn, argnums=1)(x, y)
3317
3318        x = torch.randn(4, 3)
3319        y = torch.randn(3, 4)
3320        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
3321        # Dynamic shapes produce a slightly different graph.
3322        if check_dynamic_shape_capture():
3323            return
3324
3325        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3326        self.assertExpectedInline(
3327            actual,
3328            """\
3329class GraphModule(torch.nn.Module):
3330    def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
3331        l_x_ = L_x_
3332        l_y_ = L_y_
3333
3334        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3335        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3336
3337        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = _wrap_for_grad = None
3338        diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
3339
3340        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3341
3342        _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
3343
3344        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3345
3346        o: "f32[3, 4]" = diff_primals.sin()
3347
3348        results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1)
3349
3350        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3351        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3352
3353        tensor: "i64[1]" = torch.tensor((12,))
3354        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
3355        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
3356        neg: "i64[0]" = getitem.neg();  getitem = None
3357        unbind = neg.unbind();  neg = unbind = None
3358
3359        chunk: "f32[12, 12]" = results.new_zeros(12, 12);  results = None
3360
3361        diagonal: "f32[12]" = chunk.diagonal(0)
3362        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
3363
3364        basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4);  chunk = None
3365
3366        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
3367
3368        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
3369
3370        _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
3371
3372        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim);  _vjp_treespec_compare = None
3373
3374        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None
3375        batched_outputs: "f32[3, 4]" = _autograd_grad[0];  _autograd_grad = None
3376
3377        chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None
3378
3379        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
3380
3381        split = chunked_result.split((12,), dim = 0);  chunked_result = None
3382        split_1: "f32[12, 3, 4]" = split[0];  split = None
3383
3384        output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4));  split_1 = None
3385        return (output_input,)
3386""",
3387        )
3388
3389    def test_jacrev_has_aux(self):
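        # With has_aux=True the auxiliary output (`x` here) is unwrapped from the
        # grad level and returned alongside the Jacobian, as the final
        # `return (output_input, aux_1)` in the expected graph shows.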
3390        counters.clear()
3391
3392        def fn(x, y):
3393            return y.sin(), x
3394
3395        def wrapper_fn(x, y):
3396            return torch.func.jacrev(fn, argnums=1, has_aux=True)(x, y)
3397
3398        x = torch.randn(4, 3)
3399        y = torch.randn(3, 4)
3400        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
3401        # Dynamic shapes produce a slightly different graph.
3402        if check_dynamic_shape_capture():
3403            return
3404
3405        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3406        self.assertExpectedInline(
3407            actual,
3408            """\
3409class GraphModule(torch.nn.Module):
3410    def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
3411        l_x_ = L_x_
3412        l_y_ = L_y_
3413
3414        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3415        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3416
3417        aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3418        diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
3419
3420        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3421
3422        _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
3423
3424        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3425
3426        o: "f32[3, 4]" = diff_primals.sin()
3427
3428        aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
3429
3430        results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1)
3431
3432        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3433        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3434
3435        tensor: "i64[1]" = torch.tensor((12,))
3436        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
3437        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
3438        neg: "i64[0]" = getitem.neg();  getitem = None
3439        unbind = neg.unbind();  neg = unbind = None
3440
3441        chunk: "f32[12, 12]" = results.new_zeros(12, 12);  results = None
3442
3443        diagonal: "f32[12]" = chunk.diagonal(0)
3444        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
3445
3446        basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4);  chunk = None
3447
3448        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
3449
3450        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
3451
3452        _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
3453
3454        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim);  _vjp_treespec_compare = None
3455
3456        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None
3457        batched_outputs: "f32[3, 4]" = _autograd_grad[0];  _autograd_grad = None
3458
3459        chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None
3460
3461        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
3462
3463        split = chunked_result.split((12,), dim = 0);  chunked_result = None
3464        split_1: "f32[12, 3, 4]" = split[0];  split = None
3465
3466        output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4));  split_1 = None
3467        return (output_input, aux_1)
3468""",
3469        )
3470
3471    def test_jacrev_disable_capture(self):
3472        counters.clear()
3473
3474        with config.patch(capture_func_transforms=False):
3475            # We have verified above that this function compiles with capture
3476            # enabled; here capture is disabled, so we expect graph breaks instead.
3477            def wrapper_fn(x):
3478                return torch.func.jacrev(torch.sin)(x)
3479
3480            x = torch.randn(3, 3, 3)
3481            actual = wrapper_fn(x)
3482            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
3483                x
3484            )
3485            self.assertEqual(len(counters["graph_break"]), 2)
3486            self.assertEqual(
3487                dict(counters["graph_break"]),
3488                {
3489                    "torch.func.vmap capture is disabled, it can be "
3490                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2,
3491                    "torch.func.jacrev capture is disabled, it can be "
3492                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1,
3493                },
3494            )
3495            self.assertEqual(actual, expected)
3496
3497    def test_vjp(self):
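        # Only the primal output of torch.func.vjp is returned and vjpfunc is never
        # called, so the captured graph contains just the wrapped forward pass and
        # no _autograd_grad call.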
3498        counters.clear()
3499
3500        def fn(x):
3501            return x.sin().sum()
3502
3503        def wrapper_fn(x, v):
3504            (out, vjpfunc) = torch.func.vjp(fn, x)
3505            return out
3506
3507        x = torch.randn([5])
3508        v = torch.randn(5)
3509        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
3510
3511        # Dynamic shapes produce a slightly different graph.
3512        if check_dynamic_shape_capture():
3513            return
3514
3515        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3516        self.assertExpectedInline(
3517            actual,
3518            """\
3519class GraphModule(torch.nn.Module):
3520    def forward(self, L_x_: "f32[5]"):
3521        l_x_ = L_x_
3522
3523        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3524        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3525
3526        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3527
3528        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3529
3530        child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  child_1 = None
3531
3532        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3533
3534        sin: "f32[5]" = child.sin();  child = None
3535        o: "f32[]" = sin.sum();  sin = None
3536
3537        results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1);  o = None
3538
3539        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3540        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3541        return (results,)
3542""",
3543        )
3544
3545    def test_vjp_multiple_outputs(self):
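        # The function returns a tuple, so the cotangents are passed as a matching
        # tuple; calling vjpfunc inside the compiled region shows up as a single
        # _autograd_grad node over both outputs.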
3546        counters.clear()
3547
3548        def wrapper_fn(x, v):
3549            fn = lambda x: (x.sin(), x.cos())  # noqa: E731
3550            (out, vjpfunc) = torch.func.vjp(fn, x)
3551            vjps = vjpfunc((v, v))
3552            return out, vjps
3553
3554        x = torch.randn([5])
3555        v = torch.randn(5)
3556        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
3557
3558        # Dynamic shapes produce a slightly different graph.
3559        if check_dynamic_shape_capture():
3560            return
3561
3562        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3563        self.assertExpectedInline(
3564            actual,
3565            """\
3566class GraphModule(torch.nn.Module):
3567    def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
3568        l_x_ = L_x_
3569        l_v_ = L_v_
3570
3571        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3572        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3573
3574        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3575
3576        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3577
3578        child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
3579
3580        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3581
3582        child_1: "f32[5]" = child.sin()
3583        child_2: "f32[5]" = child.cos();  child = None
3584
3585        _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
3586        _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
3587
3588        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3589        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3590
3591        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare((child_1, child_2), (l_v_, l_v_));  _vjp_treespec_compare = None
3592
3593        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, l_v_], retain_graph = True, create_graph = True);  child_1 = child_2 = child_3 = l_v_ = None
3594        getitem: "f32[5]" = _autograd_grad[0];  _autograd_grad = None
3595        return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
3596""",
3597        )
3598
3599    def test_vjp_multiple_outputs_python_struct(self):
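        # Same as the previous test but with dict-structured outputs and cotangents;
        # _vjp_treespec_compare receives the dicts while _autograd_grad gets the
        # flattened tensors.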
3600        counters.clear()
3601
3602        def wrapper_fn(x, v):
3603            fn = lambda x: {"first": x.sin(), "second": x.cos()}  # noqa: E731
3604            (out, vjpfunc) = torch.func.vjp(fn, x)
3605            vjps = vjpfunc({"first": v, "second": v.sin()})
3606            return out, vjps
3607
3608        x = torch.randn([5])
3609        v = torch.randn(5)
3610        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
3611
3612        # Dynamic shapes produce a slightly different graph.
3613        if check_dynamic_shape_capture():
3614            return
3615
3616        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3617        self.assertExpectedInline(
3618            actual,
3619            """\
3620class GraphModule(torch.nn.Module):
3621    def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
3622        l_x_ = L_x_
3623        l_v_ = L_v_
3624
3625        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3626        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3627
3628        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3629
3630        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3631
3632        child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
3633
3634        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3635
3636        child_1: "f32[5]" = child.sin()
3637        child_2: "f32[5]" = child.cos();  child = None
3638
3639        _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
3640        _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
3641
3642        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3643        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3644
3645        child_4: "f32[5]" = l_v_.sin()
3646
3647        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare({'first': child_1, 'second': child_2}, {'first': l_v_, 'second': child_4});  _vjp_treespec_compare = None
3648
3649        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True);  child_1 = child_2 = child_3 = l_v_ = child_4 = None
3650        getitem: "f32[5]" = _autograd_grad[0];  _autograd_grad = None
3651        return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
3652""",
3653        )
3654
3655    def test_vjp_has_aux(self):
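        # The aux value is produced by fn but dropped by the wrapper, so the graph
        # unwraps it and discards it, returning only the primal result.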
3656        counters.clear()
3657
3658        def fn(x):
3659            return x.sin().sum(), x
3660
3661        def wrapper_fn(x, v):
3662            (out, vjpfunc, _) = torch.func.vjp(fn, x, has_aux=True)
3663            return out
3664
3665        x = torch.randn([5])
3666        v = torch.randn(5)
3667        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
3668
3669        # Dynamic shapes produce a slightly different graph.
3670        if check_dynamic_shape_capture():
3671            return
3672
3673        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3674        self.assertExpectedInline(
3675            actual,
3676            """\
3677class GraphModule(torch.nn.Module):
3678    def forward(self, L_x_: "f32[5]"):
3679        l_x_ = L_x_
3680
3681        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3682        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3683
3684        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3685
3686        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3687
3688        child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  child_1 = None
3689
3690        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3691
3692        sin: "f32[5]" = child.sin()
3693        o: "f32[]" = sin.sum();  sin = None
3694
3695        aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1);  child = aux = None
3696
3697        results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1);  o = None
3698
3699        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3700        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3701        return (results,)
3702""",
3703        )
3704
3705    def test_vjp_disable_capture(self):
3706        counters.clear()
3707
3708        with config.patch(capture_func_transforms=False):
3709            # We have verified above that this function compiles with capture
3710            # enabled; here capture is disabled, so we expect a graph break instead.
3711            def wrapper_fn(x):
3712                (out, vjpfunc) = torch.func.vjp(torch.sin, x)
3713                return out
3714
3715            x = torch.randn(3, 3, 3)
3716            actual = wrapper_fn(x)
3717            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
3718                x
3719            )
3720            self.assertEqual(len(counters["graph_break"]), 1)
3721            self.assertEqual(
3722                dict(counters["graph_break"]),
3723                {
3724                    "torch.func.vjp capture is disabled, it can be "
3725                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1
3726                },
3727            )
3728            self.assertEqual(actual, expected)
3729
3730    @config.patch(inline_inbuilt_nn_modules=True)
3731    def test_functional_call(self):
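        # torch.func.functional_call with a params dict: with inline_inbuilt_nn_modules
        # the parameters become graph inputs and linear is traced functionally;
        # otherwise the module is invoked as an attribute of the graph module.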
3732        def wrapper_fn(model, params, inputs, targets):
3733            prediction = torch.func.functional_call(model, params, (inputs,))
3734            return torch.nn.functional.mse_loss(prediction, targets)
3735
3736        model = torch.nn.Linear(3, 3)
3737        params = dict(model.named_parameters())
3738        inputs = torch.randn(64, 3)
3739        targets = torch.randn(64, 3)
3740
3741        wrapped_gm = self._compile_check(wrapper_fn, (model, params, inputs, targets))
3742        # Dynamic shapes produce a slightly different graph.
3743        if check_dynamic_shape_capture():
3744            return
3745
3746        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3747        if torch._dynamo.config.inline_inbuilt_nn_modules:
3748            self.assertExpectedInline(
3749                actual,
3750                """\
3751class GraphModule(torch.nn.Module):
3752    def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bias_: "f32[3]", L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
3753        l_model_parameters_weight_ = L_model_parameters_weight_
3754        l_model_parameters_bias_ = L_model_parameters_bias_
3755        l_inputs_ = L_inputs_
3756        l_targets_ = L_targets_
3757
3758        prediction: "f32[64, 3]" = torch._C._nn.linear(l_inputs_, l_model_parameters_weight_, l_model_parameters_bias_);  l_inputs_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None
3759
3760        mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_);  prediction = l_targets_ = None
3761        return (mse_loss,)
3762""",
3763            )
3764        else:
3765            self.assertExpectedInline(
3766                actual,
3767                """\
3768class GraphModule(torch.nn.Module):
3769    def forward(self, L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
3770        l_inputs_ = L_inputs_
3771        l_targets_ = L_targets_
3772
3773        prediction: "f32[64, 3]" = self.model(l_inputs_);  l_inputs_ = None
3774
3775        mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_);  prediction = l_targets_ = None
3776        return (mse_loss,)
3777""",
3778            )
3779
3780    @config.patch(inline_inbuilt_nn_modules=True)
3781    def test_functional_call_sequential_params_and_buffers(self):
3782        # copied from test/test_stateless.py
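        # Passes params and buffers to functional_call as a tuple of two dicts and
        # checks the captured graph; compiled with fullgraph=False.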
3783        class MockModule(torch.nn.Module):
3784            def __init__(self) -> None:
3785                super().__init__()
3786                self.l1 = torch.nn.Linear(1, 1)
3787                self.register_buffer("buffer", torch.ones(1))
3788                self.foo = 0.0
3789
3790            def forward(self, x):
3791                return self.l1(x) + self.buffer
3792
3793        def wrapper_fn(model, params, buffers, inputs):
3794            # params and buffers are passed as two separate dictionaries
3795            return torch.func.functional_call(model, (params, buffers), inputs)
3796
3797        model = MockModule()
3798        params = dict(model.named_parameters())
3799        buffers = dict(model.named_buffers())
3800        inputs = torch.tensor([[1.5]])
3801
3802        wrapped_gm = self._compile_check(
3803            wrapper_fn, (model, params, buffers, inputs), fullgraph=False
3804        )
3805        # Dynamic shapes produce a slightly different graph.
3806        if check_dynamic_shape_capture():
3807            return
3808
3809        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3810        if torch._dynamo.config.inline_inbuilt_nn_modules:
3811            expected = """\
3812class GraphModule(torch.nn.Module):
3813    def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"):
3814        l_params_l1_weight_ = L_params_l1_weight_
3815        l_params_l1_bias_ = L_params_l1_bias_
3816        l_buffers_buffer_ = L_buffers_buffer_
3817        l_inputs_ = L_inputs_
3818
3819        linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_);  l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None
3820
3821        add: "f32[1, 1]" = linear + l_buffers_buffer_;  linear = l_buffers_buffer_ = None
3822        return (add,)
3823"""
3824            # Windows and Linux builds can differ in blank lines; empty_line_normalizer normalizes them before comparing.
3825            self.assertExpectedInline(
3826                empty_line_normalizer(actual),
3827                empty_line_normalizer(normalize_gm(expected)),
3828            )
3829        else:
3830            self.assertExpectedInline(
3831                actual,
3832                """\
3833class GraphModule(torch.nn.Module):
3834    def forward(self, L_x_: "f32[1, 1]"):
3835        l_x_ = L_x_
3836
3837        l__self___l1: "f32[1, 1]" = self.L__self___l1(l_x_);  l_x_ = None
3838        l__self___buffer: "f32[1]" = self.L__self___buffer
3839        add: "f32[1, 1]" = l__self___l1 + l__self___buffer;  l__self___l1 = l__self___buffer = None
3840        return (add,)
3841""",
3842            )
3843
3844    @config.patch(inline_inbuilt_nn_modules=True)
3845    def test_functional_call_disable_capture(self):
3846        counters.clear()
3847
3848        with config.patch(capture_func_transforms=False):
3849            # We have verified above that this function compiles with capture
3850            # enabled; here capture is disabled, so we expect a graph break instead.
3851            def wrapper_fn(model, params, inputs, targets):
3852                prediction = torch.func.functional_call(model, params, (inputs,))
3853                return torch.nn.functional.mse_loss(prediction, targets)
3854
3855            model = torch.nn.Linear(3, 3)
3856            params = dict(model.named_parameters())
3857            inputs = torch.randn(64, 3)
3858            targets = torch.randn(64, 3)
3859
3860            actual = wrapper_fn(model, params, inputs, targets)
3861            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
3862                model, params, inputs, targets
3863            )
3864            self.assertEqual(len(counters["graph_break"]), 1)
3865            self.assertEqual(
3866                {
3867                    "torch.func.functional_call capture is disabled, it can be "
3868                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1,
3869                },
3870                dict(counters["graph_break"]),
3871            )
3872            self.assertEqual(actual, expected)
3873
3874    @config.patch(inline_inbuilt_nn_modules=False)
3875    def test_functional_call_disable_inline_nn_module(self):
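        # With inline_inbuilt_nn_modules=False, functional_call capture is disabled:
        # expect exactly one graph break with the message below, while the eager and
        # compiled results still match.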
3876        counters.clear()
3877
3878        def wrapper_fn(model, params, inputs, targets):
3879            prediction = torch.func.functional_call(model, params, (inputs,))
3880            return torch.nn.functional.mse_loss(prediction, targets)
3881
3882        model = torch.nn.Linear(3, 3)
3883        params = dict(model.named_parameters())
3884        inputs = torch.randn(64, 3)
3885        targets = torch.randn(64, 3)
3886
3887        actual = wrapper_fn(model, params, inputs, targets)
3888        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
3889            model, params, inputs, targets
3890        )
3891        self.assertEqual(len(counters["graph_break"]), 1)
3892        self.assertEqual(
3893            {
3894                "torch.func.functional_call capture is disabled, it can be "
3895                "turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1,
3896            },
3897            dict(counters["graph_break"]),
3898        )
3899        self.assertEqual(actual, expected)
3900
3901    def test_grad(self):
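        # Checks that torch.func.grad(fn) is captured into a single graph: the input
        # is wrapped for grad, the forward is traced, and _autograd_grad computes the
        # gradient before unwrapping. In eager terms, grad(fn)(x) here equals x.cos(),
        # the gradient of x.sin().sum() with respect to x.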
3902        counters.clear()
3903
3904        def fn(x):
3905            return x.sin().sum()
3906
3907        def wrapper_fn(x):
3908            return torch.func.grad(fn)(x)
3909
3910        x = torch.randn(3, 3, 3)
3911        wrapped_gm = self._compile_check(wrapper_fn, (x,))
3912
3913        # Dynamic shapes produce a slightly different graph.
3914        if check_dynamic_shape_capture():
3915            return
3916
3917        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3918        self.assertExpectedInline(
3919            actual,
3920            """\
3921class GraphModule(torch.nn.Module):
3922    def forward(self, L_x_: "f32[3, 3, 3]"):
3923        l_x_ = L_x_
3924
3925        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3926        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3927
3928        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3929
3930        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3931
3932        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
3933
3934        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
3935
3936        sin: "f32[3, 3, 3]" = diff_args.sin()
3937        output: "f32[]" = sin.sum();  sin = None
3938
3939        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
3940        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
3941
3942        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
3943
3944        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
3945
3946        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
3947        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
3948        return (grad_input_1,)
3949""",
3950        )
3951
3952    def test_grad_freevar_tensor(self):
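        # fn closes over a tensor defined at test scope; here we only check that the
        # compiled result matches eager (no expected graph is asserted).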
3953        counters.clear()
3954        y = torch.randn(3, 3)
3955
3956        def fn(x):
3957            return (x.sin() + y).sum()
3958
3959        def wrapper_fn(x):
3960            return torch.func.grad(fn)(x)
3961
3962        x = torch.randn(3, 3, 3)
3963        expected = wrapper_fn(x)
3964        actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
3965        self.assertEqual(actual, expected)
3966
3967    def test_grad_freevar_python_scalar(self):
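        # The closed-over Python int is baked into the graph as a constant
        # (`sin + 3` in the expected output below).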
3968        counters.clear()
3969        y = 3
3970
3971        def fn(x):
3972            return (x.sin() + y).sum()
3973
3974        def wrapper_fn(x):
3975            return torch.func.grad(fn)(x)
3976
3977        x = torch.randn(3, 3, 3)
3978        wrapped_gm = self._compile_check(wrapper_fn, (x,))
3979
3980        # Dynamic shapes produce a slightly different graph.
3981        if check_dynamic_shape_capture():
3982            return
3983
3984        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
3985        self.assertExpectedInline(
3986            actual,
3987            """\
3988class GraphModule(torch.nn.Module):
3989    def forward(self, L_x_: "f32[3, 3, 3]"):
3990        l_x_ = L_x_
3991
3992        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
3993        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
3994
3995        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
3996
3997        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
3998
3999        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
4000
4001        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4002
4003        sin: "f32[3, 3, 3]" = diff_args.sin()
4004        add: "f32[3, 3, 3]" = sin + 3;  sin = None
4005        output: "f32[]" = add.sum();  add = None
4006
4007        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
4008        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
4009
4010        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
4011
4012        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4013
4014        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4015        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4016        return (grad_input_1,)
4017""",
4018        )
4019
4020    def test_grad_capture_tensor(self):
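        # `y` is created inside the compiled wrapper, so torch.randn appears in the
        # graph and `y` is also returned as a graph output alongside the gradient.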
4021        counters.clear()
4022
4023        def wrapper_fn(x):
4024            y = torch.randn(3)
4025
4026            def fn(x):
4027                return (x.sin() + y).sum()
4028
4029            return torch.func.grad(fn)(x)
4030
4031        x = torch.randn(3, 3, 3)
4032
4033        wrapped_gm = self._compile_check(wrapper_fn, (x,))
4034
4035        # Dynamic shapes produce a slightly different graph.
4036        if check_dynamic_shape_capture():
4037            return
4038
4039        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4040        self.assertExpectedInline(
4041            actual,
4042            """\
4043class GraphModule(torch.nn.Module):
4044    def forward(self, L_x_: "f32[3, 3, 3]"):
4045        l_x_ = L_x_
4046
4047        y: "f32[3]" = torch.randn(3)
4048
4049        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4050        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4051
4052        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4053
4054        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4055
4056        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
4057
4058        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4059
4060        sin: "f32[3, 3, 3]" = diff_args.sin()
4061        add: "f32[3, 3, 3]" = sin + y;  sin = None
4062        output: "f32[]" = add.sum();  add = None
4063
4064        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
4065        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
4066
4067        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
4068
4069        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4070
4071        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4072        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4073        return (y, grad_input_1)
4074""",
4075        )
4076
4077    def test_grad_closure_scalar(self):
4078        counters.clear()
4079
4080        def wrapper_fn(x):
4081            y = 3.14
4082
4083            def fn(x):
4084                return (x.sin() + y).sum()
4085
4086            return torch.func.grad(fn)(x)
4087
4088        x = torch.randn(3, 3, 3)
4089
4090        # Graph break: Dynamo is unable to get the source of `fn`, and the
4091        # functools.wraps call inside `grad` leads to a graph break.
4092        wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False)
4093
4094        # Dynamic shapes produce a slightly different graph.
4095        if check_dynamic_shape_capture():
4096            return
4097
4098        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4099        self.assertExpectedInline(
4100            actual,
4101            """\
4102class GraphModule(torch.nn.Module):
4103    def forward(self, L_x_: "f32[3, 3, 3]"):
4104        l_x_ = L_x_
4105
4106        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4107        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4108
4109        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4110
4111        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4112
4113        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
4114
4115        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4116
4117        sin: "f32[3, 3, 3]" = diff_args.sin()
4118        add: "f32[3, 3, 3]" = sin + 3.14;  sin = None
4119        output: "f32[]" = add.sum();  add = None
4120
4121        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
4122        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
4123
4124        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
4125
4126        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4127
4128        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4129        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4130        return (grad_input_1,)
4131""",
4132        )
4133
4134    def test_grad_has_aux(self):
4135        counters.clear()
4136
4137        y = 3.14
4138
4139        def fn(x):
4140            return ((x.sin() + y).sum(), x.cos())
4141
4142        def wrapper_fn(x):
4143            return torch.func.grad(fn, has_aux=True)(x)
4144
4145        x = torch.randn(3, 3, 3)
4146        wrapped_gm = self._compile_check(wrapper_fn, (x,))
4147
4148        # Dynamic shapes produce a slightly different graph.
4149        if check_dynamic_shape_capture():
4150            return
4151
4152        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4153        self.assertExpectedInline(
4154            actual,
4155            """\
4156class GraphModule(torch.nn.Module):
4157    def forward(self, L_x_: "f32[3, 3, 3]"):
4158        l_x_ = L_x_
4159
4160        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4161        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4162
4163        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4164
4165        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4166
4167        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
4168
4169        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4170
4171        sin: "f32[3, 3, 3]" = diff_args.sin()
4172        add: "f32[3, 3, 3]" = sin + 3.14;  sin = None
4173        output: "f32[]" = add.sum();  add = None
4174        aux: "f32[3, 3, 3]" = diff_args.cos()
4175
4176        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
4177        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
4178
4179        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
4180
4181        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4182
4183        aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
4184
4185        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4186        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4187        return (grad_input_1, aux_1)
4188""",
4189        )
4190
4191    def test_grad_two_tensor_has_aux(self):
4192        counters.clear()
4193
4194        def fn(x, y):
4195            return ((x.sin() + y).sum(), x.cos())
4196
4197        def wrapper_fn(x, y):
4198            return torch.func.grad(fn, has_aux=True)(x, y)
4199
4200        y = torch.randn(3, 3, 3)
4201        x = torch.randn(3, 3, 3)
4202        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
4203
4204        # Dynamic shapes produce a slightly different graph.
4205        if check_dynamic_shape_capture():
4206            return
4207
4208        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4209        self.assertExpectedInline(
4210            actual,
4211            """\
4212class GraphModule(torch.nn.Module):
4213    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
4214        l_x_ = L_x_
4215        l_y_ = L_y_
4216
4217        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4218        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4219
4220        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4221        _wrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
4222
4223        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4224
4225        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
4226
4227        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4228
4229        sin: "f32[3, 3, 3]" = diff_args.sin()
4230        add: "f32[3, 3, 3]" = sin + _wrap_for_grad_1;  sin = _wrap_for_grad_1 = None
4231        output: "f32[]" = add.sum();  add = None
4232        aux: "f32[3, 3, 3]" = diff_args.cos()
4233
4234        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
4235        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
4236
4237        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
4238
4239        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4240
4241        aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
4242
4243        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4244        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4245        return (grad_input_1, aux_1)
4246""",
4247        )
4248
4249    def test_grad_two_tensor_all_grad_has_aux(self):
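        # argnums is passed both as a literal tuple and via a captured tuple variable;
        # both spellings are expected to be captured the same way, returning gradients
        # for x and y plus the aux output.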
4250        counters.clear()
4251
4252        nums = (0, 1)
4253
4254        def fn(x, y):
4255            return ((x.sin() + y).sum(), x.cos())
4256
4257        def wrapper_fn_const_var(x, y):
4258            return torch.func.grad(fn, argnums=(0, 1), has_aux=True)(x, y)
4259
4260        def wrapper_fn_tuple_var(x, y):
4261            return torch.func.grad(fn, argnums=nums, has_aux=True)(x, y)
4262
4263        y = torch.randn(3, 3, 3)
4264        x = torch.randn(3, 3, 3)
4265        wrapped_gm_const_var = self._compile_check(wrapper_fn_const_var, (x, y))
4266        wrapped_gm_tuple_var = self._compile_check(wrapper_fn_tuple_var, (x, y))
4267
4268        # Dynamic shapes produce a slightly different graph.
4269        if check_dynamic_shape_capture():
4270            return
4271
4272        actual_const_var = normalize_gm(
4273            wrapped_gm_const_var.print_readable(print_output=False)
4274        )
4275        actual_tuple_var = normalize_gm(
4276            wrapped_gm_tuple_var.print_readable(print_output=False)
4277        )
4278        self.assertExpectedInline(
4279            actual_const_var,
4280            """\
4281class GraphModule(torch.nn.Module):
4282    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
4283        l_x_ = L_x_
4284        l_y_ = L_y_
4285
4286        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4287        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4288
4289        child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4290        child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
4291
4292        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4293
4294        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  _set_tensor_requires_grad = None
4295
4296        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4297        set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed_2 = None
4298
4299        _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1);  _set_tensor_requires_grad_1 = None
4300
4301        set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_3 = None
4302
4303        sin: "f32[3, 3, 3]" = child.sin()
4304        add: "f32[3, 3, 3]" = sin + child_1;  sin = None
4305        output: "f32[]" = add.sum();  add = None
4306        aux: "f32[3, 3, 3]" = child.cos()
4307
4308        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True);  child = child_1 = None
4309        child_2: "f32[3, 3, 3]" = _autograd_grad[0]
4310        child_3: "f32[3, 3, 3]" = _autograd_grad[1];  _autograd_grad = None
4311
4312        _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1);  child_2 = None
4313        _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1);  child_3 = None
4314
4315        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4316
4317        aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
4318
4319        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4320        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4321        return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1)
4322""",
4323        )
4324        self.assertExpectedInline(
4325            actual_tuple_var,
4326            """\
4327class GraphModule(torch.nn.Module):
4328    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
4329        l_x_ = L_x_
4330        l_y_ = L_y_
4331
4332        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4333        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4334
4335        child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4336        child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
4337
4338        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4339
4340        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  _set_tensor_requires_grad = None
4341
4342        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4343        set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed_2 = None
4344
4345        _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1);  _set_tensor_requires_grad_1 = None
4346
4347        set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_3 = None
4348
4349        sin: "f32[3, 3, 3]" = child.sin()
4350        add: "f32[3, 3, 3]" = sin + child_1;  sin = None
4351        output: "f32[]" = add.sum();  add = None
4352        aux: "f32[3, 3, 3]" = child.cos()
4353
4354        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True);  child = child_1 = None
4355        child_2: "f32[3, 3, 3]" = _autograd_grad[0]
4356        child_3: "f32[3, 3, 3]" = _autograd_grad[1];  _autograd_grad = None
4357
4358        _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1);  child_2 = None
4359        _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1);  child_3 = None
4360
4361        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4362
4363        aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
4364
4365        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4366        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4367        return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1)
4368""",
4369        )
4370
4371    def test_grad_over_grad(self):
4372        counters.clear()
4373
4374        def fn(x):
4375            return x.sin().sum()
4376
4377        def wrapper_fn(x):
4378            return torch.func.grad(torch.func.grad(fn))(x)
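        # Illustrative sketch: for a scalar x, grad(fn)(x) ~= x.cos() and the nested
        # grad(grad(fn))(x) ~= -x.sin(), i.e. the second derivative of sin. The captured
        # graph below nests two grad levels (_wrap_for_grad at levels 1 and 2).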
4379
4380        x = torch.randn(())
4381        wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False)
4382
4383        if check_dynamic_shape_capture():
4384            return
4385
4386        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4387        self.assertExpectedInline(
4388            actual,
4389            """\
4390class GraphModule(torch.nn.Module):
4391    def forward(self, L_x_: "f32[]"):
4392        l_x_ = L_x_
4393
4394        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4395        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4396
4397        diff_args: "f32[]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4398
4399        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4400
4401        _set_tensor_requires_grad: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
4402
4403        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4404        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable_1 = None
4405        _grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting_1 = None
4406
4407        diff_args_1: "f32[]" = torch._C._functorch._wrap_for_grad(diff_args, 2)
4408
4409        set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed_2 = None
4410
4411        _set_tensor_requires_grad_1: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1);  _set_tensor_requires_grad_1 = None
4412
4413        set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_3 = None
4414
4415        sin: "f32[]" = diff_args_1.sin()
4416        output: "f32[]" = sin.sum();  sin = None
4417
4418        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True);  diff_args_1 = None
4419        grad_input: "f32[]" = _autograd_grad[0];  _autograd_grad = None
4420
4421        grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2);  grad_input = None
4422
4423        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2);  output = output_1 = None
4424
4425        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4426        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable_2 = None
4427
4428        _autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True);  diff_args = None
4429        grad_input_2: "f32[]" = _autograd_grad_1[0];  _autograd_grad_1 = None
4430
4431        grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1);  grad_input_2 = None
4432
4433        output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1);  grad_input_1 = output_2 = None
4434
4435        _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting_1 = None
4436        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4437        return (grad_input_3,)
4438""",
4439        )
4440
4441    def test_grad_with_graph_break(self):
4442        counters.clear()
4443
4444        def fn(x):
4445            torch._dynamo.graph_break()
4446            return x.sin().sum()
4447
4448        def wrapper_fn(x):
4449            return torch.func.grad(fn)(x)
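        # Illustrative note: the explicit torch._dynamo.graph_break() inside fn forces
        # Dynamo to fall back for that frame, which is why exactly one graph-break
        # reason is recorded below; the eager and compiled results still match.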
4450
4451        x = torch.randn(3, 3, 3)
4452        actual = wrapper_fn(x)
4453        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
4454        self.assertEqual(len(counters["graph_break"]), 1)
4455        self.assertEqual(actual, expected)
4456
4457    def test_grad_with_side_effect(self):
4458        counters.clear()
4459
4460        foo = [1, 2]
4461
4462        def fn(x):
4463            foo.append(3)
4464            return x.sin().sum()
4465
4466        def wrapper_fn(x):
4467            return torch.func.grad(fn)(x)
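        # Illustrative note: mutating the Python list `foo` inside fn is handled by
        # Dynamo's side-effect tracking, so no graph break is recorded below.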
4468
4469        x = torch.randn(3, 3, 3)
4470        actual = wrapper_fn(x)
4471        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
4472        self.assertEqual(len(counters["graph_break"]), 0)
4473        self.assertEqual(actual, expected)
4474
4475    def test_grad_pytree(self):
4476        counters.clear()
4477
4478        def fn(x):
4479            x1, x2 = x
4480            return x1.sin().sum() + x2
4481
4482        def wrapper_fn(x):
4483            return torch.func.grad(fn)(x)
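        # Illustrative sketch: grad mirrors the input pytree structure, so
        #     torch.func.grad(fn)((x1, x2))  # ~= (x1.cos(), torch.ones_like(x2))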
4484
4485        x1 = torch.randn(3, 3, 3)
4486        x2 = torch.randn(())
4487        actual = wrapper_fn((x1, x2))
4488        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
4489            (x1, x2)
4490        )
4491        self.assertEqual(len(counters["graph_break"]), 0)
4492        self.assertEqual(actual, expected)
4493
4494    def test_grad_non_tensor_input(self):
4495        counters.clear()
4496
4497        def fn(x, y):
4498            return x.sin().sum() + y
4499
4500        def wrapper_fn(x, y):
4501            return torch.func.grad(fn)(x, y)
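        # Illustrative sketch: with the default argnums=0 only x is differentiated;
        # the Python float y is treated as a constant (it shows up inlined as
        # `sum_1 + 3.0` in the captured graph below), so the result is roughly x.cos().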
4502
4503        x = torch.randn(3, 3, 3)
4504        y = 3.0
4505        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
4506
4507        # Dynamic shapes produce a slightly different graph.
4508        if check_dynamic_shape_capture():
4509            return
4510
4511        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4512        self.assertExpectedInline(
4513            actual,
4514            """\
4515class GraphModule(torch.nn.Module):
4516    def forward(self, L_x_: "f32[3, 3, 3]"):
4517        l_x_ = L_x_
4518
4519        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
4520        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
4521
4522        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
4523
4524        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
4525
4526        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
4527
4528        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
4529
4530        sin: "f32[3, 3, 3]" = diff_args.sin()
4531        sum_1: "f32[]" = sin.sum();  sin = None
4532        output: "f32[]" = sum_1 + 3.0;  sum_1 = None
4533
4534        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
4535        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
4536
4537        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
4538
4539        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1);  output = output_1 = None
4540
4541        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
4542        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
4543        return (grad_input_1,)
4544""",
4545        )
4546
4547    def test_grad_disable_capture(self):
4548        counters.clear()
4549
4550        with config.patch(capture_func_transforms=False):
            # We have verified above that this function compiles.
4553            def fn(x):
4554                return x.sin().sum()
4555
4556            def wrapper_fn(x):
4557                return torch.func.grad(fn)(x)
4558
4559            x = torch.randn(3, 3)
4560            actual = wrapper_fn(x)
4561            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
4562                x
4563            )
4564            self.assertEqual(len(counters["graph_break"]), 1)
4565            self.assertEqual(
4566                dict(counters["graph_break"]),
4567                {
4568                    "torch.func.grad capture is disabled, it can be turned "
4569                    "on by setting `torch._dynamo.config.capture_func_transforms=True`": 2
4570                },
4571            )
4572            self.assertEqual(actual, expected)
4573
4574    def test_grad_fn_with_kwargs(self):
4575        def fn(x, y):
4576            return (x + y).sum()
4577
4578        def wrapper_fn(x, y):
4579            return torch.func.grad(fn)(x, y=y)
4580
4581        x = torch.randn(3, 3)
4582        y = torch.randn(3, 3)
4583        actual = wrapper_fn(x, y)
4584        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
4585        self.assertEqual(len(counters["graph_break"]), 0)
4586        self.assertEqual(actual, expected)
4587
4588    def test_jacfwd(self):
4589        counters.clear()
4590
4591        def wrapper_fn(x):
4592            return torch.func.jacfwd(torch.sin)(x)
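        # Illustrative eager-mode sketch of roughly what jacfwd does here (names basis,
        # rows, jac are illustrative only): push a batch of one-hot basis tangents
        # through jvp under vmap and reshape the result, roughly
        #     basis = torch.eye(x.numel()).view(-1, *x.shape)          # [12, 4, 3]
        #     rows = torch.vmap(lambda t: torch.func.jvp(torch.sin, (x,), (t,))[1])(basis)
        #     jac = rows.movedim(0, -1).unflatten(-1, x.shape)         # [4, 3, 4, 3]
        # The captured graph below mirrors this: new_zeros(12, 12) with its diagonal
        # filled is the basis, and for elementwise sin the Jacobian is diag(cos(x)).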
4593
4594        x = torch.randn(4, 3)
4595        wrapped_gm = self._compile_check(wrapper_fn, (x,))
4596        # Dynamic shapes produce a slightly different graph.
4597        if check_dynamic_shape_capture():
4598            return
4599
4600        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4601        self.assertExpectedInline(
4602            actual,
4603            """\
4604class GraphModule(torch.nn.Module):
4605    def forward(self, L_x_: "f32[4, 3]"):
4606        l_x_ = L_x_
4607
4608        tensor: "i64[1]" = torch.tensor((12,))
4609        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
4610        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
4611        neg: "i64[0]" = getitem.neg();  getitem = None
4612        unbind = neg.unbind();  neg = unbind = None
4613
4614        chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
4615
4616        diagonal: "f32[12]" = chunk.diagonal(0)
4617        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
4618
4619        child: "f32[12, 4, 3]" = chunk.view(12, 4, 3);  chunk = None
4620
4621        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
4622
4623        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
4624
4625        child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
4626
4627        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,));  _jvp_treespec_compare = None
4628
4629        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
4630        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
4631        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
4632
4633        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
4634
4635        _make_dual: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
4636
4637        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
4638
4639        result_duals: "f32[4, 3]" = torch.sin(_make_dual);  _make_dual = None
4640
4641        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
4642        primal: "f32[4, 3]" = _unpack_dual[0]
4643        dual: "f32[4, 3]" = _unpack_dual[1];  _unpack_dual = None
4644
4645        primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
4646
4647        tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
4648
4649        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
4650        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
4651        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
4652
4653        results: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0);  tangents_out_unflatten = None
4654
4655        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
4656
4657        movedim: "f32[4, 3, 12]" = results.movedim(0, -1);  results = None
4658        split = movedim.split((12,), dim = -1);  movedim = None
4659        jac_out_in: "f32[4, 3, 12]" = split[0];  split = None
4660
4661        unflatten: "f32[4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3));  jac_out_in = None
4662        return (unflatten,)
4663""",
4664        )
4665
4666    def test_jacfwd_two_tensors_argnums(self):
4667        counters.clear()
4668
4669        def fn(x, y):
4670            return y.sin()
4671
4672        def wrapper_fn(x, y):
4673            return torch.func.jacfwd(fn, argnums=1)(x, y)
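        # Illustrative sketch: argnums=1 differentiates only with respect to y, so the
        # basis tangents are built from y (l_y_.new_zeros(12, 12) in the graph below)
        # and the Jacobian has shape [3, 4, 3, 4] (output shape followed by y's shape).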
4674
4675        x = torch.randn(4, 3)
4676        y = torch.randn(3, 4)
4677        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
4678        # Dynamic shapes produce a slightly different graph.
4679        if check_dynamic_shape_capture():
4680            return
4681
4682        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4683        self.assertExpectedInline(
4684            actual,
4685            """\
4686class GraphModule(torch.nn.Module):
4687    def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
4688        l_x_ = L_x_
4689        l_y_ = L_y_
4690
4691        tensor: "i64[1]" = torch.tensor((12,))
4692        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
4693        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
4694        neg: "i64[0]" = getitem.neg();  getitem = None
4695        unbind = neg.unbind();  neg = unbind = None
4696
4697        chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
4698
4699        diagonal: "f32[12]" = chunk.diagonal(0)
4700        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
4701
4702        child: "f32[12, 3, 4]" = chunk.view(12, 3, 4);  chunk = None
4703
4704        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
4705
4706        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
4707
4708        child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
4709
4710        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,));  _jvp_treespec_compare = None
4711
4712        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
4713        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
4714        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
4715
4716        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
4717
4718        _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
4719
4720        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
4721        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
4722
4723        result_duals: "f32[3, 4]" = _make_dual.sin();  _make_dual = None
4724
4725        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
4726        primal: "f32[3, 4]" = _unpack_dual[0]
4727        dual: "f32[3, 4]" = _unpack_dual[1];  _unpack_dual = None
4728
4729        primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
4730
4731        tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
4732
4733        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
4734        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
4735        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
4736
4737        results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0);  tangents_out_unflatten = None
4738
4739        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
4740
4741        movedim: "f32[3, 4, 12]" = results.movedim(0, -1);  results = None
4742        split = movedim.split((12,), dim = -1);  movedim = None
4743        jac_out_in: "f32[3, 4, 12]" = split[0];  split = None
4744
4745        unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4));  jac_out_in = None
4746        return (unflatten,)
4747""",
4748        )
4749
4750    def test_jacfwd_has_aux(self):
4751        counters.clear()
4752
4753        def fn(x, y):
4754            return y.sin(), x
4755
4756        def wrapper_fn(x, y):
4757            return torch.func.jacfwd(fn, argnums=1, has_aux=True)(x, y)
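        # Illustrative sketch: with has_aux=True the result is (jacobian, aux). Since
        # aux (here x) does not depend on the basis tangents, the captured graph removes
        # its vmap batch dim and keeps a single slice (aux_2[0]) before returning it.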
4758
4759        x = torch.randn(4, 3)
4760        y = torch.randn(3, 4)
4761        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
4762        # Dynamic shapes produce a slightly different graph.
4763        if check_dynamic_shape_capture():
4764            return
4765
4766        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4767        self.assertExpectedInline(
4768            actual,
4769            """\
4770class GraphModule(torch.nn.Module):
4771    def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
4772        l_x_ = L_x_
4773        l_y_ = L_y_
4774
4775        tensor: "i64[1]" = torch.tensor((12,))
4776        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
4777        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
4778        neg: "i64[0]" = getitem.neg();  getitem = None
4779        unbind = neg.unbind();  neg = unbind = None
4780
4781        chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
4782
4783        diagonal: "f32[12]" = chunk.diagonal(0)
4784        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
4785
4786        child: "f32[12, 3, 4]" = chunk.view(12, 3, 4);  chunk = None
4787
4788        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
4789
4790        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
4791
4792        child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
4793
4794        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,));  _jvp_treespec_compare = None
4795
4796        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
4797        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
4798        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
4799
4800        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
4801
4802        _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
4803
4804        aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = None
4805        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
4806
4807        result_duals: "f32[3, 4]" = _make_dual.sin();  _make_dual = None
4808
4809        aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 2);  aux = None
4810
4811        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
4812        primal: "f32[3, 4]" = _unpack_dual[0]
4813        dual: "f32[3, 4]" = _unpack_dual[1];  _unpack_dual = None
4814
4815        primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
4816
4817        tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
4818
4819        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
4820        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
4821        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
4822
4823        results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0);  tangents_out_unflatten = None
4824        aux_2: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(aux_1, 1, 12, 0);  aux_1 = None
4825
4826        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
4827
4828        aux_3: "f32[4, 3]" = aux_2[0];  aux_2 = None
4829
4830        movedim: "f32[3, 4, 12]" = results.movedim(0, -1);  results = None
4831        split = movedim.split((12,), dim = -1);  movedim = None
4832        jac_out_in: "f32[3, 4, 12]" = split[0];  split = None
4833
4834        unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4));  jac_out_in = None
4835        return (unflatten, aux_3)
4836""",
4837        )
4838
4839    def test_jacfwd_randomness(self):
4840        counters.clear()
4841
4842        def fn(x, y):
4843            return y.sin(), x
4844
4845        def wrapper_fn(x, y):
4846            return torch.func.jacfwd(fn, randomness="same")(x, y)
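        # Illustrative note: the randomness="same" argument is forwarded to the
        # underlying vmap, visible below as _vmap_increment_nesting(12, 'same')
        # instead of the default 'error' seen in test_jacfwd above.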
4847
4848        x = torch.randn(4, 3)
4849        y = torch.randn(3, 4)
4850        wrapped_gm = self._compile_check(wrapper_fn, (x, y))
4851        # Dynamic shapes produce a slightly different graph.
4852        if check_dynamic_shape_capture():
4853            return
4854
4855        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4856        self.assertExpectedInline(
4857            actual,
4858            """\
4859class GraphModule(torch.nn.Module):
4860    def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
4861        l_x_ = L_x_
4862        l_y_ = L_y_
4863
4864        tensor: "i64[1]" = torch.tensor((12,))
4865        cumsum: "i64[1]" = tensor.cumsum(dim = 0);  tensor = None
4866        getitem: "i64[0]" = cumsum[slice(None, -1, None)];  cumsum = None
4867        neg: "i64[0]" = getitem.neg();  getitem = None
4868        unbind = neg.unbind();  neg = unbind = None
4869
4870        chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
4871
4872        diagonal: "f32[12]" = chunk.diagonal(0)
4873        fill_: "f32[12]" = diagonal.fill_(1);  diagonal = fill_ = None
4874
4875        child: "f32[12, 4, 3]" = chunk.view(12, 4, 3);  chunk = None
4876
4877        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
4878
4879        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same');  _vmap_increment_nesting = None
4880
4881        child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
4882
4883        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,));  _jvp_treespec_compare = None
4884
4885        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
4886        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
4887        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
4888
4889        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
4890
4891        child_3: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
4892
4893        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
4894        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = None
4895
4896        child_2: "f32[3, 4]" = _wrap_for_grad_1.sin();  _wrap_for_grad_1 = None
4897
4898        _unpack_dual = torch._unpack_dual(child_2, level = 0);  child_2 = None
4899        primal: "f32[3, 4]" = _unpack_dual[0];  _unpack_dual = None
4900
4901        tangent: "f32[3, 4]" = torch.zeros_like(primal)
4902
4903        _unpack_dual_1 = torch._unpack_dual(child_3, level = 0);  child_3 = None
4904        primal_1: "f32[4, 3]" = _unpack_dual_1[0]
4905        dual: "f32[4, 3]" = _unpack_dual_1[1];  _unpack_dual_1 = None
4906
4907        child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = child_4 = None
4908        child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2);  primal_1 = child_5 = None
4909
4910        child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2);  tangent = None
4911        child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
4912
4913        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
4914        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
4915        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
4916
4917        child_8: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_6, 1, 12, 0);  child_6 = None
4918        child_9: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(child_7, 1, 12, 0);  child_7 = None
4919
4920        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
4921
4922        movedim: "f32[3, 4, 12]" = child_8.movedim(0, -1);  child_8 = None
4923        split = movedim.split((12,), dim = -1);  movedim = None
4924        jac_out_in: "f32[3, 4, 12]" = split[0];  split = None
4925
4926        unflatten: "f32[3, 4, 4, 3]" = jac_out_in.unflatten(-1, (4, 3));  jac_out_in = None
4927
4928        movedim_1: "f32[4, 3, 12]" = child_9.movedim(0, -1);  child_9 = None
4929        split_1 = movedim_1.split((12,), dim = -1);  movedim_1 = None
4930        jac_out_in_1: "f32[4, 3, 12]" = split_1[0];  split_1 = None
4931
4932        unflatten_1: "f32[4, 3, 4, 3]" = jac_out_in_1.unflatten(-1, (4, 3));  jac_out_in_1 = None
4933        return (unflatten, unflatten_1)
4934""",
4935        )
4936
4937    def test_jacfwd_disable_capture(self):
4938        counters.clear()
4939
4940        with config.patch(capture_func_transforms=False):
4941            # We have verified above that this
4942            # function compiles
4943            def wrapper_fn(x):
4944                return torch.func.jacfwd(torch.sin)(x)
4945
4946            x = torch.randn(3, 3, 3)
4947            actual = wrapper_fn(x)
4948            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
4949                x
4950            )
4951            self.assertEqual(len(counters["graph_break"]), 2)
4952            self.assertEqual(
4953                dict(counters["graph_break"]),
4954                {
4955                    "torch.func.vmap capture is disabled, it can be "
4956                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2,
4957                    "torch.func.jacfwd capture is disabled, it can be "
4958                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1,
4959                },
4960            )
4961            self.assertEqual(actual, expected)
4962
4963    def test_jvp_simple(self):
4964        counters.clear()
4965
4966        def fn(x):
4967            return x.sin().sum()
4968
4969        def wrapper_fn(x, v):
4970            return torch.func.jvp(fn, (x,), (v,))
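        # Illustrative sketch: jvp returns the primal output and the directional
        # derivative of fn at x along v, roughly
        #     primal, tangent = torch.func.jvp(fn, (x,), (v,))
        #     # primal ~= x.sin().sum(), tangent ~= (x.cos() * v).sum()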
4971
4972        x = torch.randn(3, 3)
4973        v = torch.randn(3, 3)
4974        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
4975
4976        # Dynamic shapes produce a slightly different graph.
4977        if check_dynamic_shape_capture():
4978            return
4979
4980        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
4981        self.assertExpectedInline(
4982            actual,
4983            """\
4984class GraphModule(torch.nn.Module):
4985    def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
4986        l_x_ = L_x_
4987        l_v_ = L_v_
4988
4989        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,));  _jvp_treespec_compare = None
4990
4991        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
4992        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
4993        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
4994
4995        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
4996
4997        _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
4998
4999        sin: "f32[3, 3]" = _make_dual.sin();  _make_dual = None
5000        result_duals: "f32[]" = sin.sum();  sin = None
5001
5002        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
5003        primal: "f32[]" = _unpack_dual[0]
5004        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
5005
5006        primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
5007
5008        tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1);  dual = None
5009
5010        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
5011        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
5012        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
5013        return (primals_out_unflatten, tangents_out_unflatten)
5014""",
5015        )
5016
5017    def test_jvp_has_aux(self):
5018        counters.clear()
5019
5020        def fn(x):
5021            return x.sin().sum(), x
5022
5023        def wrapper_fn(x, v):
5024            return torch.func.jvp(fn, (x,), (v,), has_aux=True)
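        # Illustrative sketch: with has_aux=True, jvp returns (primal, tangent, aux).
        # Here aux is x itself, so in the captured graph the dual tensor built from
        # (l_x_, l_v_) doubles as the aux value and is unwrapped before being returned.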
5025
5026        x = torch.randn(3, 3)
5027        v = torch.randn(3, 3)
5028        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
5029
5030        # Dynamic shapes produce a slightly different graph.
5031        if check_dynamic_shape_capture():
5032            return
5033
5034        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5035        self.assertExpectedInline(
5036            actual,
5037            """\
5038class GraphModule(torch.nn.Module):
5039    def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
5040        l_x_ = L_x_
5041        l_v_ = L_v_
5042
5043        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,));  _jvp_treespec_compare = None
5044
5045        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
5046        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
5047        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
5048
5049        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
5050
5051        aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
5052
5053        sin: "f32[3, 3]" = aux.sin()
5054        result_duals: "f32[]" = sin.sum();  sin = None
5055
5056        aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
5057
5058        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
5059        primal: "f32[]" = _unpack_dual[0]
5060        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
5061
5062        primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
5063
5064        tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1);  dual = None
5065
5066        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
5067        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
5068        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
5069        return (primals_out_unflatten, tangents_out_unflatten, aux_1)
5070""",
5071        )
5072
5073    def test_jvp_two_tensors_has_aux(self):
5074        counters.clear()
5075
5076        def fn(x, y):
5077            return (x.sin().sum() + y.cos()), x
5078
5079        def wrapper_fn(x, y, v):
5080            return torch.func.jvp(fn, (x, y), (v, v), has_aux=True)
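        # Illustrative sketch: the same tangent v is paired with both primals, so the
        # output tangent is roughly (x.cos() * v).sum() - y.sin() * v, and aux (x as a
        # dual tensor) is unwrapped and returned alongside the primal/tangent pair.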
5081
5082        x = torch.randn(3, 3)
5083        y = torch.randn(3, 3)
5084        v = torch.randn(3, 3)
5085        wrapped_gm = self._compile_check(wrapper_fn, (x, y, v))
5086
5087        # Dynamic shapes produce a slightly different graph.
5088        if check_dynamic_shape_capture():
5089            return
5090
5091        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5092        self.assertExpectedInline(
5093            actual,
5094            """\
5095class GraphModule(torch.nn.Module):
5096    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]", L_v_: "f32[3, 3]"):
5097        l_x_ = L_x_
5098        l_y_ = L_y_
5099        l_v_ = L_v_
5100
5101        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_, l_y_), (l_v_, l_v_));  _jvp_treespec_compare = None
5102
5103        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
5104        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
5105        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
5106
5107        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
5108
5109        aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = None
5110
5111        _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions_1 = None
5112
5113        _make_dual_1: "f32[3, 3]" = torch._make_dual(l_y_, l_v_, level = 0);  l_y_ = l_v_ = None
5114
5115        sin: "f32[3, 3]" = aux.sin()
5116        sum_1: "f32[]" = sin.sum();  sin = None
5117        cos: "f32[3, 3]" = _make_dual_1.cos();  _make_dual_1 = None
5118        result_duals: "f32[3, 3]" = sum_1 + cos;  sum_1 = cos = None
5119
5120        aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
5121
5122        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
5123        primal: "f32[3, 3]" = _unpack_dual[0]
5124        dual: "f32[3, 3]" = _unpack_dual[1];  _unpack_dual = None
5125
5126        primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
5127
5128        tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1);  dual = None
5129
5130        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
5131        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
5132        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
5133        return (primals_out_unflatten, tangents_out_unflatten, aux_1)
5134""",
5135        )
5136
5137    def test_jvp_two_tensors_disable_grad(self):
5138        counters.clear()
5139
5140        def fn(x):
5141            return x.sin().sum()
5142
5143        def wrapper_fn(x, v):
5144            with torch.autograd.forward_ad._set_fwd_grad_enabled(False):
5145                return torch.func.jvp(fn, (x,), (v,))
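        # Illustrative note: torch.func.jvp re-enables forward-mode AD internally, so it
        # still works inside a _set_fwd_grad_enabled(False) region; the captured graph
        # below toggles the flag off, on for the transform, back off when the transform
        # unwinds, and finally restores True when the context manager exits.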
5146
5147        x = torch.randn(3, 3)
5148        v = torch.randn(3, 3)
5149        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
5150
5151        # Dynamic shapes produce a slightly different graph.
5152        if check_dynamic_shape_capture():
5153            return
5154
5155        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5156        self.assertExpectedInline(
5157            actual,
5158            """\
5159class GraphModule(torch.nn.Module):
5160    def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
5161        l_x_ = L_x_
5162        l_v_ = L_v_
5163
5164        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False);  _set_fwd_grad_enabled = None
5165
5166        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,));  _jvp_treespec_compare = None
5167
5168        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
5169        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
5170        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
5171
5172        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
5173
5174        _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
5175
5176        sin: "f32[3, 3]" = _make_dual.sin();  _make_dual = None
5177        result_duals: "f32[]" = sin.sum();  sin = None
5178
5179        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
5180        primal: "f32[]" = _unpack_dual[0]
5181        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
5182
5183        primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
5184
5185        tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1);  dual = None
5186
5187        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
5188        _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False);  _set_fwd_grad_enabled_2 = None
5189        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
5190        _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_3 = None
5191        return (primals_out_unflatten, tangents_out_unflatten)
5192""",
5193        )
5194
5195    def test_jvp_two_tensors_disable_enable_disable_grad(self):
5196        counters.clear()
5197
5198        def fn(x):
5199            return x.sin().sum()
5200
5201        def wrapper_fn(x, v):
5202            with torch.autograd.forward_ad._set_fwd_grad_enabled(False):  # (1)
5203                with torch.autograd.forward_ad._set_fwd_grad_enabled(True):  # (2)
5204                    with torch.autograd.forward_ad._set_fwd_grad_enabled(False):  # (3)
5205                        return torch.func.jvp(fn, (x,), (v,))  # (4)
5206
5207            # Start True
5208            # False      (1)
5209            #   True     (2)
5210            #     False  (3)
5211            #       True (4)
5212            #     True   (undo 3)
5213            #   False    (undo 2)
5214            # True       (undo 1)
5215
5216        x = torch.randn(3, 3)
5217        v = torch.randn(3, 3)
5218        wrapped_gm = self._compile_check(wrapper_fn, (x, v))
5219
5220        # Dynamic shapes produce a slightly different graph.
5221        if check_dynamic_shape_capture():
5222            return
5223
5224        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5225        self.assertExpectedInline(
5226            actual,
5227            """\
5228class GraphModule(torch.nn.Module):
5229    def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
5230        l_x_ = L_x_
5231        l_v_ = L_v_
5232
5233        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False);  _set_fwd_grad_enabled = None
5234        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
5235        _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False);  _set_fwd_grad_enabled_2 = None
5236
5237        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,));  _jvp_treespec_compare = None
5238
5239        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
5240        _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_3 = None
5241        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
5242
5243        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
5244
5245        _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
5246
5247        sin: "f32[3, 3]" = _make_dual.sin();  _make_dual = None
5248        result_duals: "f32[]" = sin.sum();  sin = None
5249
5250        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
5251        primal: "f32[]" = _unpack_dual[0]
5252        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
5253
5254        primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
5255
5256        tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1);  dual = None
5257
5258        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
5259        _set_fwd_grad_enabled_4 = torch._C._set_fwd_grad_enabled(False);  _set_fwd_grad_enabled_4 = None
5260        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
5261        _set_fwd_grad_enabled_5 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_5 = None
5262        _set_fwd_grad_enabled_6 = torch._C._set_fwd_grad_enabled(False);  _set_fwd_grad_enabled_6 = None
5263        _set_fwd_grad_enabled_7 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_7 = None
5264        return (primals_out_unflatten, tangents_out_unflatten)
5265""",
5266        )
5267
5268    def test_jvp_freevar_tensor(self):
5269        counters.clear()
5270        y = torch.randn(3, 3)
5271
5272        def fn(x):
5273            return (x.sin() + y).sum()
5274
5275        def wrapper_fn(x):
5276            return torch.func.jvp(fn, (x,), (x,))
5277
5278        x = torch.randn(3, 3)
5279        expected = wrapper_fn(x)
5280        actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
5281        self.assertEqual(actual, expected)
5282
5283    def test_jvp_jvp(self):
5284        counters.clear()
5285
5286        if check_dynamic_shape_capture():
5287            self.skipTest("test fails with dynamic shapes")
5288
5289        def fn(x):
5290            return torch.func.jvp(torch.sin, (x,), (x,))
5291
5292        def wrapper_fn(x):
5293            return torch.func.jvp(fn, (x,), (x,))
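        # Illustrative sketch: the inner jvp produces roughly (sin(x), cos(x) * x) as
        # dual tensors of the outer level, so the outer jvp unpacks both and the
        # compiled graph returns four tensors: the primal and tangent of each inner output.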
5294
5295        x = torch.randn(3, 3, 3)
5296        wrapped_gm = self._compile_check(wrapper_fn, (x,))
5297
5298        # Dynamic shapes produce a slightly different graph.
5299        if check_dynamic_shape_capture():
5300            return
5301
5302        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5303        self.assertExpectedInline(
5304            actual,
5305            """\
5306class GraphModule(torch.nn.Module):
5307    def forward(self, L_x_: "f32[3, 3, 3]"):
5308        l_x_ = L_x_
5309
5310        _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_x_,));  _jvp_treespec_compare = None
5311
5312        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting = None
5313        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled = None
5314        _enter_dual_level = torch._C._enter_dual_level();  _enter_dual_level = None
5315
5316        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
5317
5318        child: "f32[3, 3, 3]" = torch._make_dual(l_x_, l_x_, level = 0);  l_x_ = None
5319
5320        _jvp_treespec_compare_1 = torch._functorch.eager_transforms._jvp_treespec_compare((child,), (child,));  _jvp_treespec_compare_1 = None
5321
5322        _jvp_increment_nesting_1 = torch._C._functorch._jvp_increment_nesting();  _jvp_increment_nesting_1 = None
5323        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
5324
5325        _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions_1 = None
5326
5327        _make_dual_1: "f32[3, 3, 3]" = torch._make_dual(child, child, level = 0);  child = None
5328
5329        result_duals: "f32[3, 3, 3]" = torch.sin(_make_dual_1);  _make_dual_1 = None
5330
5331        _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
5332        primal: "f32[3, 3, 3]" = _unpack_dual[0]
5333        dual: "f32[3, 3, 3]" = _unpack_dual[1];  _unpack_dual = None
5334
5335        primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = None
5336
5337        tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
5338
5339        _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_2 = None
5340        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
5341
5342        _unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0);  primals_out_unflatten = None
5343        primal_1: "f32[3, 3, 3]" = _unpack_dual_1[0]
5344        dual_1: "f32[3, 3, 3]" = _unpack_dual_1[1];  _unpack_dual_1 = None
5345        _unpack_dual_2 = torch._unpack_dual(tangents_out_unflatten, level = 0);  tangents_out_unflatten = None
5346        primal_2: "f32[3, 3, 3]" = _unpack_dual_2[0]
5347        dual_2: "f32[3, 3, 3]" = _unpack_dual_2[1];  _unpack_dual_2 = None
5348
5349        _unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1);  primal_1 = None
5350        _unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1);  primal_2 = None
5351
5352        _unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1);  dual_1 = None
5353        _unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1);  dual_2 = None
5354
5355        _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
5356        _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_3 = None
5357        _jvp_decrement_nesting_1 = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting_1 = None
5358        return (_unwrap_for_grad_2, _unwrap_for_grad_3, _unwrap_for_grad_4, _unwrap_for_grad_5)
5359""",
5360        )
5361
5362    def test_jvp_freevar_python_scalar(self):
5363        counters.clear()
5364        y = 3
5365
5366        def fn(x):
5367            return (x.sin() + y).sum()
5368
5369        def wrapper_fn(x):
5370            return torch.func.jvp(fn, (x,), (x,))
5371
5372        x = torch.randn(3, 3, 3)
5373        expected = wrapper_fn(x)
5374        actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
5375        self.assertEqual(actual, expected)
5376
5377    def test_jvp_disable_capture(self):
5378        counters.clear()
5379
5380        with config.patch(capture_func_transforms=False):
5381            # We have verified above that this
5382            # function compiles
5383            def wrapper_fn(x):
5384                return torch.func.jvp(torch.sin, (x,), (x,))
5385
5386            x = torch.randn(3, 3, 3)
5387            actual = wrapper_fn(x)
5388            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
5389                x
5390            )
5391            self.assertEqual(len(counters["graph_break"]), 1)
5392            self.assertEqual(
5393                dict(counters["graph_break"]),
5394                {
5395                    "torch.func.jvp capture is disabled, it can be "
5396                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1
5397                },
5398            )
            self.assertEqual(actual, expected)
5400
5401    @config.patch(capture_func_transforms=True)
5402    def test_linearize_jvp_fn(self):
5403        counters.clear()
5404
5405        def wrapper_fn(x):
5406            output, jvp_fn = torch.func.linearize(torch.sin, x)
5407            return output, jvp_fn(x)
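        # Illustrative sketch: torch.func.linearize(torch.sin, x) evaluates sin(x) once
        # and returns it together with a jvp_fn; for sin, jvp_fn(v) is roughly
        # x.cos() * v. The second captured graph below reflects this: the tangent is
        # multiplied by a const-folded cos(x) buffer (mul.Tensor at the end).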
5408
5409        x = torch.randn(3, 3, 3)
5410        wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=0)
5411
5412        # Dynamic shapes produce a slightly different graph.
5413        if check_dynamic_shape_capture():
5414            return
5415
5416        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5417        self.assertExpectedInline(
5418            actual,
5419            """\
5420class GraphModule(torch.nn.Module):
5421    def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
5422        l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
5423
5424        alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_);  l_self_buffers_tensor_constant0_ = None
5425
5426        sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)
5427
5428        alias_default_1: "f32[3, 3, 3]" = torch.ops.aten.alias.default(alias_default)
5429
5430        cos_default: "f32[3, 3, 3]" = torch.ops.aten.cos.default(alias_default_1);  alias_default_1 = None
5431
5432        alias_default_2: "f32[3, 3, 3]" = torch.ops.aten.alias.default(sin_default);  alias_default_2 = None
5433        return (alias_default, cos_default, sin_default)
5434""",
5435        )
5436
5437        wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=1)
5438        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5439        self.assertExpectedInline(
5440            actual,
5441            """\
5442class GraphModule(torch.nn.Module):
5443    def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
5444        l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
5445        l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
5446        l_flat_tangents_1_ = L_flat_tangents_1_
5447
5448        _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_);  l_self_modules_fx_const_folded_attrs_parameters_0_ = None
5449
5450        copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_);  _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None
5451
5452        mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_);  copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
5453        return (mul_tensor,)
5454""",
5455        )
5456
5457    def test_linearize_disable_capture(self):
5458        counters.clear()
5459        with config.patch(capture_func_transforms=False):
5460            # We have verified above that this
5461            # function compiles
5462            def wrapper_fn(x):
5463                out, _ = torch.func.linearize(torch.sin, x)
5464                return out
5465
5466            x = torch.randn(2, 3)
5467            actual = wrapper_fn(x)
5468            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
5469                x
5470            )
5471            self.assertEqual(len(counters["graph_break"]), 1)
5472            self.assertEqual(
5473                {
5474                    "torch.func.linearize capture is disabled, it can be "
5475                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1,
5476                },
5477                dict(counters["graph_break"]),
5478            )
5479            self.assertEqual(actual, expected)
5480
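    # A second call under an identical vmap configuration should reuse the cached
    # graph; error_on_recompile turns any unexpected recompile into a hard failure.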
5481    @config.patch(capture_func_transforms=True)
5482    @config.patch(error_on_recompile=True)
5483    def test_vmap_recompile(self):
5484        @torch.compile(backend="eager")
5485        def fn(x):
5486            return torch.vmap(lambda x: x.sin())(x)
5487
5488        x = torch.zeros(3, 3, 4, 5)
5489        y = torch.vmap(fn)(x)
5490        # should not recompile on second call. See PyTorch issue #118493
5491        y = torch.vmap(fn)(x)
5492
5493    @xfailIfTorchDynamo
5494    @config.patch(error_on_recompile=True)
5495    def test_vmap_recompile_different_config(self):
5496        @torch.compile(backend="eager")
5497        def fn(x):
5498            return torch.vmap(lambda x: x.sin())(x)
5499
5500        x = torch.zeros(3, 3, 4, 5)
5501        y = torch.vmap(fn)(x)
5502        with self.assertRaises(torch._dynamo.exc.RecompileError):
5503            fn(x)
5504
5505    @config.patch(error_on_recompile=True)
5506    def test_vmap_recompile_same_config(self):
5507        @torch.compile(backend="eager")
5508        def fn(x):
5509            return torch.vmap(lambda x: x.sin())(x)
5510
5511        x = torch.zeros(3, 3, 4, 5)
5512        torch.vmap(torch.vmap(fn, randomness="same"), randomness="same")(x)
5513        with self.assertRaises(torch._dynamo.exc.RecompileError):
5514            torch.vmap(torch.vmap(fn, randomness="same"), randomness="error")(x)
5515
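    # Changing vmap's randomness between calls is expected to force a recompile,
    # which error_on_recompile surfaces as a RecompileError.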
5516    @config.patch(error_on_recompile=True)
5517    def test_vmap_recompile_with_randomness(self):
5518        @torch.compile(backend="eager")
5519        def fn(x):
5520            return torch.vmap(lambda x: x.sin())(x)
5521
5522        x = torch.zeros(3, 3, 4, 5)
5523        torch.vmap(fn, randomness="same")(x)
5524        with self.assertRaises(torch._dynamo.exc.RecompileError):
5525            torch.vmap(fn, randomness="different")(x)
5526
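    # Applying torch.func transforms to an already torch.compile'd function from
    # eager mode is unsupported; the next three tests check that vmap, grad and jvp
    # each raise Unsupported with a message naming the transform.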
5527    def test_vmap_call_torch_compile_fn(self):
5528        def wrapped_fn(x):
5529            return x.sin()
5530
5531        x = torch.randn(3, 4)
5532        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)
5533
5534        with self.assertRaisesRegex(
5535            torch._dynamo.exc.Unsupported,
5536            "Calling torch.func.vmap\\(compiled_fn\\) function from eager mode is not supported",
5537        ):
5538            torch.func.vmap(fn)(x)
5539
5540    def test_grad_call_torch_compile_fn(self):
5541        def wrapped_fn(x):
5542            return x.sin().sum()
5543
5544        x = torch.randn(3, 4)
5545        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)
5546
5547        with self.assertRaisesRegex(
5548            torch._dynamo.exc.Unsupported,
5549            "Calling torch.func.grad\\(compiled_fn\\) function from eager mode is not supported",
5550        ):
5551            torch.func.grad(fn)(x)
5552
5553    def test_jvp_call_torch_compile_fn(self):
5554        def wrapped_fn(x):
5555            return x.sin().sum()
5556
5557        x = torch.randn(3, 4)
5558        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)
5559
5560        with self.assertRaisesRegex(
5561            torch._dynamo.exc.Unsupported,
5562            "Calling torch.func.jvp\\(compiled_fn\\) function from eager mode is not supported",
5563        ):
5564            torch.func.jvp(fn, (x,), (x,))
5565
5566    @config.patch(error_on_recompile=True)
5567    def test_grad_recompile(self):
5568        @torch.compile(backend="eager")
5569        def fn(x):
5570            return torch.func.grad(torch.sin)(x)
5571
5572        x = torch.randn([])
5573        torch.func.grad(fn)(x)
5574        # should not recompile on second call
5575        torch.func.grad(fn)(x)
5576
5577    def test_vmap_get_wrapped(self):
5578        counters.clear()
5579
5580        def g(x):
5581            return x.sin()
5582
5583        @torch.compile(backend="aot_eager", fullgraph=True)
5584        def fn():
5585            return torch.vmap(g)
5586
5587        x = torch.randn(3, 4)
5588        expected = torch.vmap(g)(x)
5589        wrapper = fn()
5590        got = wrapper(x)
5591        self.assertEqual(expected, got)
5592
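    # The graph break below is conditional on the vmapped element's rank: 1-D slices
    # take the sin path and break, 2-D slices take the cos path and compile cleanly.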
5593    def test_vmap_with_conditional_graph_break(self):
5594        def g(x):
5595            if len(x.shape) < 2:
5596                torch._dynamo.graph_break()
5597                return x.sin()
5598            else:
5599                return x.cos()
5600
5601        @torch.compile(backend="aot_eager")
5602        def fn(x):
5603            return torch.vmap(g)(x)
5604
5605        counters.clear()
5606        x = torch.randn(2, 3)
5607        expected = x.sin()
5608        got = fn(x)
5609        self.assertEqual(expected, got)
5610        self.assertEqual(len(counters["graph_break"]), 1)
5611
5612        counters.clear()
5613        y = torch.randn(2, 3, 4)
5614        expected = y.cos()
5615        got = fn(y)
5616        self.assertEqual(expected, got)
5617        self.assertEqual(len(counters["graph_break"]), 0)
5618
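    # print() inside the vmapped function causes a graph break; the fallback must
    # still match eager, with exactly one break recorded.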
5619    def test_vmap_with_graph_break(self):
5620        counters.clear()
5621
5622        def g(x):
5623            y = x.cos()
5624            print("hi")
5625            return y.sin()
5626
5627        def fn(x):
5628            return torch.vmap(g)(x)
5629
5630        x = torch.randn(3, 4)
5631        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
5632        expected = fn(x)
5633        got = opt(x)
5634        self.assertEqual(len(counters["graph_break"]), 1)
5635        self.assertEqual(expected, got)
5636
5637    def test_vmap_with_graph_break_2(self):
5638        counters.clear()
5639
5640        def cos(x):
5641            print("cos")
5642            return x.cos()
5643
5644        def sin(x):
5645            print("sin")
5646            return x.sin()
5647
5648        def g(x):
5649            y = cos(x)
5650            return sin(y)
5651
5652        def fn(x):
5653            return torch.vmap(g, randomness="same")(x)
5654
5655        x = torch.randn(3, 4)
5656        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
5657        expected = fn(x)
5658        got = opt(x)
5659        self.assertEqual(len(counters["graph_break"]), 1)
5660        self.assertEqual(expected, got)
5661
5662    def test_vmap_with_graph_break_lambda(self):
5663        counters.clear()
5664
5665        def sin(x):
5666            print("sin")
5667            return x.sin()
5668
5669        def fn(x):
5670            return torch.vmap(lambda x: sin(x))(x)
5671
5672        x = torch.randn(3, 4)
5673        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
5674        expected = fn(x)
5675        got = opt(x)
5676        self.assertEqual(len(counters["graph_break"]), 1)
5677        self.assertEqual(expected, got)
5678
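    # Basic vmap capture: the lambda body is inlined between _vmap_increment_nesting
    # and _vmap_decrement_nesting, with _add_batch_dim/_remove_batch_dim handling the
    # mapped dimension, as the expected graph below shows.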
5679    def test_vmap(self):
5680        def fn(x):
5681            return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
5682
5683        x = torch.randn(3, 3, 3)
5684        wrapped_gm = self._compile_check(fn, (x,))
5685
5686        # Dynamic shapes produce a slightly different graph.
5687        if check_dynamic_shape_capture():
5688            return
5689
5690        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5691        self.assertExpectedInline(
5692            actual,
5693            """\
5694class GraphModule(torch.nn.Module):
5695    def forward(self, L_x_: "f32[3, 3, 3]"):
5696        l_x_ = L_x_
5697
5698        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5699
5700        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
5701
5702        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
5703
5704        sum_1: "f32[3]" = _add_batch_dim.sum(0)
5705        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
5706        batched_outputs: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
5707
5708        _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
5709
5710        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
5711        return (_remove_batch_dim,)
5712""",
5713        )
5714
5715    def test_vmap_free_const(self):
5716        y = 3
5717
5718        def fn(x):
5719            return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x)
5720
5721        x = torch.randn(3, 3, 3)
5722        wrapped_gm = self._compile_check(fn, (x,))
5723
5724        # Dynamic shapes produce a slightly different graph.
5725        if check_dynamic_shape_capture():
5726            return
5727
5728        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5729        self.assertExpectedInline(
5730            actual,
5731            """\
5732class GraphModule(torch.nn.Module):
5733    def forward(self, L_x_: "f32[3, 3, 3]"):
5734        l_x_ = L_x_
5735
5736        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5737
5738        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
5739
5740        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
5741
5742        sum_1: "f32[3]" = _add_batch_dim.sum(0)
5743        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
5744        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
5745        batched_outputs: "f32[3]" = add + 3;  add = None
5746
5747        _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
5748
5749        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
5750        return (_remove_batch_dim,)
5751""",
5752        )
5753
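    # A tensor closed over by the lambda is lifted into the graph as an extra input
    # (L_y_) rather than baked in as a constant (contrast with the Python int in
    # test_vmap_free_const above).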
5754    def test_vmap_free_tensor(self):
5755        y = torch.randn(3, 3)
5756
5757        def fn(x):
5758            return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x)
5759
5760        x = torch.randn(3, 3, 3)
5761        wrapped_gm = self._compile_check(fn, (x,))
5762
5763        # Dynamic shapes produce a slightly different graph.
5764        if check_dynamic_shape_capture():
5765            return
5766
5767        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5768        self.assertExpectedInline(
5769            actual,
5770            """\
5771class GraphModule(torch.nn.Module):
5772    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
5773        l_x_ = L_x_
5774        l_y_ = L_y_
5775
5776        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5777
5778        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
5779
5780        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
5781
5782        sum_1: "f32[3]" = _add_batch_dim.sum(0)
5783        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
5784        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
5785        batched_outputs: "f32[3, 3]" = add + l_y_;  add = l_y_ = None
5786
5787        _remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
5788
5789        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
5790        return (_remove_batch_dim,)
5791""",
5792        )
5793
5794    def test_vmap_two_inputs(self):
5795        def fn(x, y):
5796            return torch.func.vmap(
5797                lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=(0, 1)
5798            )(x, y)
5799
5800        x = torch.randn(3, 3, 3)
5801        y = torch.randn(3, 3)
5802        wrapped_gm = self._compile_check(fn, (x, y))
5803
5804        # Dynamic shapes produce a slightly different graph.
5805        if check_dynamic_shape_capture():
5806            return
5807
5808        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5809        self.assertExpectedInline(
5810            actual,
5811            """\
5812class GraphModule(torch.nn.Module):
5813    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
5814        l_x_ = L_x_
5815        l_y_ = L_y_
5816
5817        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5818
5819        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
5820
5821        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
5822        _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1);  l_y_ = None
5823
5824        sum_1: "f32[3]" = _add_batch_dim.sum(0)
5825        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
5826        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
5827        batched_outputs: "f32[3]" = add + _add_batch_dim_1;  add = _add_batch_dim_1 = None
5828
5829        _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
5830
5831        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
5832        return (_remove_batch_dim,)
5833""",
5834        )
5835
5836    def test_vmap_two_inputs_tuple_in_dims(self):
5837        in_dims = (0, 1)
5838
5839        def fn(x, y):
5840            return torch.func.vmap(
5841                lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=in_dims
5842            )(x, y)
5843
5844        x = torch.randn(3, 3, 3)
5845        y = torch.randn(3, 3)
5846        wrapped_gm = self._compile_check(fn, (x, y))
5847
5848        # Dynamic shapes produce a slightly different graph.
5849        if check_dynamic_shape_capture():
5850            return
5851
5852        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5853        self.assertExpectedInline(
5854            actual,
5855            """\
5856class GraphModule(torch.nn.Module):
5857    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
5858        l_x_ = L_x_
5859        l_y_ = L_y_
5860
5861        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5862
5863        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
5864
5865        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
5866        _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1);  l_y_ = None
5867
5868        sum_1: "f32[3]" = _add_batch_dim.sum(0)
5869        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
5870        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
5871        batched_outputs: "f32[3]" = add + _add_batch_dim_1;  add = _add_batch_dim_1 = None
5872
5873        _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
5874
5875        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
5876        return (_remove_batch_dim,)
5877""",
5878        )
5879
5880    def test_vmap_over_vmap_two_inputs(self):
5881        def fn(x, y):
5882            return torch.func.vmap(torch.func.vmap(lambda x, y: x + y, in_dims=1))(x, y)
5883
5884        x = torch.randn(3, 3, 3)
5885        y = torch.randn(3, 3, 3)
5886        wrapped_gm = self._compile_check(fn, (x, y))
5887
5888        # Dynamic shapes produce a slightly different graph.
5889        if check_dynamic_shape_capture():
5890            return
5891
5892        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5893        self.assertExpectedInline(
5894            actual,
5895            """\
5896class GraphModule(torch.nn.Module):
5897    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
5898        l_x_ = L_x_
5899        l_y_ = L_y_
5900
5901        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5902
5903        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
5904
5905        child: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
5906        child_1: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1);  l_y_ = None
5907
5908        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
5909
5910        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting_1 = None
5911
5912        _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(child, 1, 2);  child = None
5913        _add_batch_dim_3: "f32[3]" = torch._C._functorch._add_batch_dim(child_1, 1, 2);  child_1 = None
5914
5915        batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3;  _add_batch_dim_2 = _add_batch_dim_3 = None
5916
5917        batched_outputs_1: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0);  batched_outputs = None
5918
5919        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
5920
5921        _remove_batch_dim_1: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 3, 0);  batched_outputs_1 = None
5922
5923        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting_1 = None
5924        return (_remove_batch_dim_1,)
5925""",
5926        )
5927
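    # Nested vmap over a captured tensor: each vmap level contributes its own
    # increment/decrement nesting pair, and the captured x is used only inside the
    # innermost body.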
5928    def test_vmap_over_vmap_captured(self):
5929        x = torch.ones(2, 3)
5930        y = torch.ones(5, 3)
5931
5932        def fn(x):
5933            return torch.func.vmap(torch.func.vmap(lambda y: x * y))(y)
5934
5935        wrapped_gm = self._compile_check(fn, (x,))
5936
5937        # Dynamic shapes produce a slightly different graph.
5938        if check_dynamic_shape_capture():
5939            return
5940
5941        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5942        self.assertExpectedInline(
5943            actual,
5944            """\
5945class GraphModule(torch.nn.Module):
5946    def forward(self, L_y_: "f32[5, 3]", L_x_: "f32[2, 3]"):
5947        l_y_ = L_y_
5948        l_x_ = L_x_
5949
5950        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5951
5952        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error');  _vmap_increment_nesting = None
5953
5954        child: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1);  l_y_ = None
5955
5956        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
5957
5958        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting_1 = None
5959
5960        _add_batch_dim_1: "f32[]" = torch._C._functorch._add_batch_dim(child, 0, 2);  child = None
5961
5962        batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1;  l_x_ = _add_batch_dim_1 = None
5963
5964        batched_outputs_1: "f32[3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0);  batched_outputs = None
5965
5966        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
5967
5968        _remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 5, 0);  batched_outputs_1 = None
5969
5970        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting_1 = None
5971        return (_remove_batch_dim_1,)
5972""",
5973        )
5974
5975    def test_vmap_multiple_outputs(self):
5976        x = torch.ones(2, 4, 3)
5977
5978        def fn(x):
5979            return torch.vmap(lambda x: (x.sum(0), x.sum(1)))(x)
5980
5981        wrapped_gm = self._compile_check(fn, (x,))
5982
5983        # Dynamic shapes produce a slightly different graph.
5984        if check_dynamic_shape_capture():
5985            return
5986
5987        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
5988        self.assertExpectedInline(
5989            actual,
5990            """\
5991class GraphModule(torch.nn.Module):
5992    def forward(self, L_x_: "f32[2, 4, 3]"):
5993        l_x_ = L_x_
5994
5995        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
5996
5997        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error');  _vmap_increment_nesting = None
5998
5999        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
6000
6001        child: "f32[3]" = _add_batch_dim.sum(0)
6002        child_1: "f32[4]" = _add_batch_dim.sum(1);  _add_batch_dim = None
6003
6004        _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0);  child = None
6005        _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0);  child_1 = None
6006
6007        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
6008        return (_remove_batch_dim, _remove_batch_dim_1)
6009""",
6010        )
6011
6012    def test_vmap_multiple_outputs_diff_dims(self):
6013        x = torch.ones(2, 4, 3)
6014
6015        def fn(x):
6016            return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=(1, 0))(x)
6017
6018        wrapped_gm = self._compile_check(fn, (x,))
6019
6020        # Dynamic shapes produce a slightly different graph.
6021        if check_dynamic_shape_capture():
6022            return
6023
6024        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
6025        self.assertExpectedInline(
6026            actual,
6027            """\
6028class GraphModule(torch.nn.Module):
6029    def forward(self, L_x_: "f32[2, 4, 3]"):
6030        l_x_ = L_x_
6031
6032        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
6033
6034        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error');  _vmap_increment_nesting = None
6035
6036        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
6037
6038        child: "f32[3]" = _add_batch_dim.sum(0)
6039        child_1: "f32[4]" = _add_batch_dim.sum(1);  _add_batch_dim = None
6040
6041        _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1);  child = None
6042        _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0);  child_1 = None
6043
6044        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
6045        return (_remove_batch_dim, _remove_batch_dim_1)
6046""",
6047        )
6048
6049    def test_vmap_multiple_outputs_out_dims_tuple(self):
6050        x = torch.ones(2, 4, 3)
6051        out_dims = (1, 0)
6052
6053        def fn(x):
6054            return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=out_dims)(x)
6055
6056        wrapped_gm = self._compile_check(fn, (x,))
6057
6058        # Dynamic shapes produce a slightly different graph.
6059        if check_dynamic_shape_capture():
6060            return
6061
6062        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
6063        self.assertExpectedInline(
6064            actual,
6065            """\
6066class GraphModule(torch.nn.Module):
6067    def forward(self, L_x_: "f32[2, 4, 3]"):
6068        l_x_ = L_x_
6069
6070        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions = None
6071
6072        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error');  _vmap_increment_nesting = None
6073
6074        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
6075
6076        child: "f32[3]" = _add_batch_dim.sum(0)
6077        child_1: "f32[4]" = _add_batch_dim.sum(1);  _add_batch_dim = None
6078
6079        _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1);  child = None
6080        _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0);  child_1 = None
6081
6082        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
6083        return (_remove_batch_dim, _remove_batch_dim_1)
6084""",
6085        )
6086
6087    def test_vmap_kwargs(self):
6088        counters.clear()
6089        x = torch.ones(2, 3)
6090        y = torch.randn(2, 3)
6091
6092        def fn(x, y):
6093            return torch.func.vmap(lambda x, y: x + y)(x, y=y)
6094
6095        actual = fn(x, y)
6096        expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
6097        self.assertEqual(len(counters["graph_break"]), 0)
6098        self.assertEqual(actual, expected)
6099
6100    def test_vmap_pytree_inputs(self):
6101        counters.clear()
6102        x = torch.ones(2, 3)
6103        y = torch.randn(2, 3)
6104
6105        def vmap_fn(inps):
6106            x = inps["x"]
6107            y = inps["y"]
6108            return x + y
6109
6110        def fn(x, y):
6111            return torch.func.vmap(vmap_fn)({"x": x, "y": y})
6112
6113        actual = fn(x, y)
6114        expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
6115        self.assertEqual(len(counters["graph_break"]), 0)
6116        self.assertEqual(actual, expected)
6117
6118    def test_vmap_side_effects(self):
6119        counters.clear()
6120        x = torch.ones(2, 3)
6121        y = torch.randn(2, 3)
6122
6123        some_list = []
6124
6125        def f(x, y):
6126            some_list.append(1)
6127            return x + y
6128
6129        def wrapper_fn(x, y):
6130            return torch.func.vmap(f)(x, y)
6131
6132        actual = wrapper_fn(x, y)
6133        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
6134        self.assertEqual(len(counters["graph_break"]), 0)
6135        self.assertEqual(actual, expected)
6136        self.assertEqual(some_list, [1, 1])
6137
6138    @unittest.expectedFailure
6139    def test_vmap_side_effects_append_input(self):
6140        counters.clear()
6141        x = torch.ones(2, 3)
6142        y = torch.randn(2, 3)
6143
6144        some_list = []
6145
6146        def f(x, y):
6147            some_list.append(x)
6148            return x + y
6149
6150        def wrapper_fn(x, y):
6151            return torch.func.vmap(f)(x, y)
6152
6153        actual = wrapper_fn(x, y)
6154        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
6155        self.assertEqual(len(counters["graph_break"]), 0)
6156        self.assertEqual(actual, expected)
6157
6158    def test_vmap_previous_illegal_op_no_graph_break(self):
6159        counters.clear()
6160
6161        # calling .stride() inside the vmapped function previously caused a graph break
6162        def bad_fn(x):
6163            y = x.view((4, 3))
6164            y.stride()
6165            return y
6166
6167        def wrapper_fn(x):
6168            return torch.func.vmap(bad_fn)(x)
6169
6170        x = torch.randn(2, 3, 4)
6171        actual = wrapper_fn(x)
6172        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
6173        self.assertEqual(len(counters["graph_break"]), 0)
6174        self.assertEqual(actual, expected)
6175
6176    def test_vmap_disable_capture(self):
6177        counters.clear()
6178
6179        with config.patch(capture_func_transforms=False):
6180            # We have verified above that this
6181            # function compiles
6182            def wrapper_fn(x):
6183                return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
6184
6185            x = torch.randn(3, 3, 3)
6186            actual = wrapper_fn(x)
6187            expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
6188                x
6189            )
6190            self.assertEqual(len(counters["graph_break"]), 1)
6191            self.assertEqual(
6192                dict(counters["graph_break"]),
6193                {
6194                    "torch.func.vmap capture is disabled, it can be "
6195                    "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 2
6196                },
6197            )
6198            self.assertEqual(actual, expected)
6199
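    # in_dims is a plain Python int, so distinct values drive separate frames; the
    # counters below pin down how many frames and ops that costs under dynamic=True.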
6200    def test_vmap_multiple_invocation_in_dims(self):
6201        counters.clear()
6202
6203        def wrapper_fn(x, in_dims):
6204            return torch.func.vmap(torch.sum, in_dims)(x)
6205
6206        x = torch.randn(3, 3, 3, 3)
6207        cnt = CompileCounter()
6208        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
6209        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
6210        # The third invocation of `opt` makes `in_dims` a SymInt.
6211        actual = opt(x, 0), opt(x, 1), opt(x, 2)
6212        self.assertEqual(expected, actual)
6213        self.assertEqual(cnt.frame_count, 3)
6214        self.assertEqual(cnt.op_count, 21)
6215
6216    def test_vmap_multiple_invocation_out_dims(self):
6217        counters.clear()
6218
6219        def wrapper_fn(x, out_dims):
6220            return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)
6221
6222        x = torch.randn(3, 3, 3, 3)
6223        cnt = CompileCounter()
6224        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
6225        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
6226        # The third invocation of `opt` makes `out_dims` a SymInt.
6227        actual = opt(x, 0), opt(x, 1), opt(x, 2)
6228        self.assertEqual(expected, actual)
6229        self.assertEqual(cnt.frame_count, 3)
6230        self.assertEqual(cnt.op_count, 21)
6231
6232    def test_vmap_new_tensor_in_body(self):
6233        def fn(x):
6234            return x + torch.ones(3)
6235
6236        def wrapper_fn(x):
6237            return torch.func.vmap(fn)(x)
6238
6239        x = torch.randn(
6240            3,
6241        )
6242        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
6243        expected = wrapper_fn(x)
6244        actual = opt(x)
6245        self.assertEqual(expected, actual)
6246
6247    def test_vmap_new_tensor_unused_in_body(self):
6248        def fn(x):
6249            return torch.tensor(0.5)
6250
6251        def wrapper_fn(x):
6252            return torch.func.vmap(fn)(x)
6253
6254        x = torch.randn(3)
6255        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
6256        expected = wrapper_fn(x)
6257        actual = opt(x)
6258        self.assertEqual(expected, actual)
6259
6260    def test_vmap_new_tensor_implicit_via_op(self):
6261        def wrapper_fn(x):
6262            return torch.func.vmap(lambda t: torch.add(t, 0.5))(x)
6263
6264        x = torch.randn(3)
6265        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
6266        expected = wrapper_fn(x)
6267        actual = opt(x)
6268        self.assertEqual(expected, actual)
6269
6270
6271class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
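    # Runs fn both eagerly and under torch.compile with the given backend on cloned
    # inputs, then (unless skip_check) compares the outputs and the input gradients.
    # manual_seed keeps the RNG state identical across the two runs.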
6272    def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
6273        cloned_args = []
6274        for arg in args:
6275            cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))
6276
6277        torch.manual_seed(0)
6278        expected = fn(*args)
6279        expected.sum().backward()
6280
6281        opt_fn = torch.compile(fn, fullgraph=fullgraph, backend=backend)
6282        torch.manual_seed(0)
6283        result = opt_fn(*cloned_args)
6284        result.sum().backward()
6285
6286        if not skip_check:
6287            self.assertEqual(result, expected)
6288            for arg, cloned_arg in zip(args, cloned_args):
6289                self.assertEqual(arg.grad, cloned_arg.grad)
6290
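    # The aot_autograd backend below asserts op counts directly: exactly one aten.mm
    # in the forward graph and two in the backward graph for the checkpointed region.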
6291    @requires_cuda
6292    @torch._functorch.config.patch(functionalize_rng_ops=True)
6293    def test_function(self):
6294        def gn(x, y):
6295            return torch.sigmoid(torch.matmul(x, y))
6296
6297        def fn(x, y):
6298            return torch.utils.checkpoint.checkpoint(
6299                gn, torch.sin(x), y, use_reentrant=True
6300            )
6301
6302        x = torch.randn(4, 4, requires_grad=True)
6303        y = torch.randn(4, 4, requires_grad=True)
6304
6305        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
6306        bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
6307        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
6308        self._validate(fn, backend, x, y)
6309
6310    @requires_cuda
6311    @torch._functorch.config.patch(functionalize_rng_ops=True)
6312    def test_function_with_kwargs(self):
6313        def gn(x, y):
6314            return torch.sigmoid(torch.matmul(x, y))
6315
6316        def fn(x, y):
6317            return torch.utils.checkpoint.checkpoint(
6318                gn,
6319                torch.sin(x),
6320                y,
6321                use_reentrant=True,
6322                preserve_rng_state=False,
6323            )
6324
6325        x = torch.randn(4, 4, requires_grad=True)
6326        y = torch.randn(4, 4, requires_grad=True)
6327
6328        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
6329        bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
6330        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
6331        self._validate(fn, backend, x, y)
6332
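    # With functionalize_rng_ops, dropout's randomness appears as philox_rand: the
    # backend asserts one call in the forward graph and none in the backward.
    # Outputs are not compared because the dropout decomposition diverges from eager.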
6333    @requires_cuda
6334    @torch._functorch.config.patch(functionalize_rng_ops=True)
6335    def test_dropout(self):
6336        def gn(x, y):
6337            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)
6338
6339        def fn(x, y):
6340            return torch.utils.checkpoint.checkpoint(
6341                gn, torch.sin(x), y, use_reentrant=True
6342            )
6343
6344        x = torch.randn(4, 4, device="cuda", requires_grad=True)
6345        y = torch.randn(4, 4, device="cuda", requires_grad=True)
6346
6347        fw_compiler = functools.partial(
6348            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
6349        )
6350        # philox_rand is passed from the fwd, so the bwd graph should contain none
6351        bw_compiler = functools.partial(
6352            count_ops, freq=0, op=torch.ops.rngprims.philox_rand.default
6353        )
6354        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
6355        self._validate(
6356            fn, backend, x, y, skip_check=True
6357        )  # dropout decomp is known to diverge with eager
6358
6359    @requires_cuda
6360    @torch._functorch.config.patch(functionalize_rng_ops=True)
6361    def test_dropout_inductor(self):
6362        def gn(x, y):
6363            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)
6364
6365        def fn(x, y):
6366            return torch.utils.checkpoint.checkpoint(
6367                gn, torch.sin(x), y, use_reentrant=True
6368            )
6369
6370        x = torch.randn(4, 4, device="cuda", requires_grad=True)
6371        y = torch.randn(4, 4, device="cuda", requires_grad=True)
6372
6373        backend = "inductor"
6374        self._validate(
6375            fn, backend, x, y, skip_check=True
6376        )  # dropout decomp is known to diverge with eager
6377
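    # A graph break inside the checkpointed function means the checkpoint itself runs
    # outside the compiled graphs; only the surrounding torch.sin and torch.cos are
    # compiled, one frame (and one op) each.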
6378    @requires_cuda
6379    @torch._functorch.config.patch(functionalize_rng_ops=True)
6380    def test_fallback(self):
6381        def gn(x, y):
6382            torch._dynamo.graph_break()
6383            return torch.sigmoid(torch.matmul(x, y))
6384
6385        def fn(x, y):
6386            return torch.cos(
6387                torch.utils.checkpoint.checkpoint(
6388                    gn, torch.sin(x), y, use_reentrant=True
6389                ),
6390            )
6391
6392        x = torch.randn(4, 4, requires_grad=True)
6393        y = torch.randn(4, 4, requires_grad=True)
6394        args = (x, y)
6395
6396        backend = EagerAndRecordGraphs()
6397        cnt = CompileCounterWithBackend(backend)
6398
6399        expected = fn(*args)
6400        result = torch.compile(fn, backend=cnt)(*args)
6401
6402        self.assertEqual(result, expected)
6403
6404        # One graph for torch.sin on the input, and the other for torch.cos.
6405        self.assertEqual(cnt.frame_count, 2)
6406        self.assertEqual(cnt.op_count, 2)
6407        self.assertEqual(len(backend.graphs), 2)
6408
6409    @requires_cuda
6410    @torch._functorch.config.patch(functionalize_rng_ops=True)
6411    def test_module(self):
6412        class MockModule(torch.nn.Module):
6413            def __init__(self) -> None:
6414                super().__init__()
6415                self.linear = torch.nn.Linear(10, 10)
6416
6417            def forward(self, x):
6418                return torch.sigmoid(self.linear(x))
6419
6420        mod = MockModule()
6421
6422        def fn(x):
6423            return torch.utils.checkpoint.checkpoint(
6424                mod, torch.sin(x), use_reentrant=True
6425            )
6426
6427        x = torch.randn(10, 10, requires_grad=True)
6428
6429        fw_compiler = functools.partial(
6430            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
6431        )
6432        # sigmoid output is passed from the fwd, so the bwd graph should contain no sigmoid
6433        bw_compiler = functools.partial(
6434            count_ops, freq=0, op=torch.ops.aten.sigmoid.default
6435        )
6436        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
6437        self._validate(fn, backend, x)
6438
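    # Registering a py_impl for a key that is a fallthrough by default should
    # override the fallthrough: the key moves into non_fallthrough_keys and the
    # registered kernel becomes reachable via py_kernels.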
6439    def test_override_fallthrough_dispatch_key(self):
6440        class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
6441            def __init__(self):
6442                super().__init__("_fallthrough_test_only")
6443
6444            def __call__(self, *args, **kwargs):
6445                return super().__call__(*args, **kwargs)
6446
6447        test_op = _FallthroughTestOnly()
6448        default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
6449        self.assertTrue(
6450            not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)
6451        )
6452
6453        foos = [lambda x=i: x for i, _ in enumerate(default_keys)]
6454        for foo, fallthrough_key in zip(foos, default_keys):
6455            test_op.py_impl(fallthrough_key)(foo)
6456
6457        self.assertTrue(
6458            all(test_op.non_fallthrough_keys.has(key) for key in default_keys)
6459        )
6460        self.assertEqual(
6461            list(range(len(default_keys))),
6462            [test_op.py_kernels[key]() for key in default_keys],
6463        )
6464
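    # cond_op accepts its arguments as keywords; both predicate values should be
    # served by a single compiled frame and agree with eager.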
6465    def test_cond_with_kwargs(self):
6466        from torch._higher_order_ops.cond import cond_op
6467
6468        def test(pred, x):
6469            def true_fn(x):
6470                return x
6471
6472            def false_fn(x):
6473                return -x
6474
6475            return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])
6476
6477        cnt = CompileCounter()
6478        opt_test = torch.compile(test, backend=cnt, fullgraph=True)
6479        inp = torch.ones(3, 3)
6480        true_pred = torch.Tensor([True])
6481        false_pred = torch.Tensor([False])
6482        self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
6483        self.assertEqual(cnt.frame_count, 1)
6484        self.assertTrue(
6485            torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
6486        )
6487        self.assertEqual(cnt.frame_count, 1)
6488
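    # An unknown kwarg, or passing pred both positionally and by keyword, should be
    # rejected (UncapturedHigherOrderOpError / AssertionError) rather than traced.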
6489    def test_cond_with_invalid_kwargs(self):
6490        from torch._higher_order_ops.cond import cond_op
6491
6492        def test(pred, mode, x):
6493            def true_fn(x):
6494                return x
6495
6496            def false_fn(x):
6497                return -x
6498
6499            if mode:
6500                return cond_op(
6501                    pred=pred,
6502                    true_fn=true_fn,
6503                    false_fn=false_fn,
6504                    operands=[x],
6505                    invalid=True,
6506                )
6507            else:
6508                return cond_op(
6509                    pred,
6510                    pred=pred,
6511                    true_fn=true_fn,
6512                    false_fn=false_fn,
6513                    operands=[x],
6514                )
6515
6516        cnt = CompileCounter()
6517        opt_test = torch.compile(test, backend=cnt)
6518        inp = torch.ones(3, 3)
6519        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
6520            opt_test(True, True, inp)
6521
6522        with self.assertRaises(AssertionError):
6523            opt_test(True, False, inp)
6524
6525    def test_non_aliasing_util(self):
6526        from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing
6527
6528        a = [torch.tensor(1), {"a": torch.tensor(1)}]
6529        b = (torch.tensor(1),)
6530        _assert_tensors_nonaliasing(a, b)
6531
6532        with self.assertRaisesRegex(
6533            AssertionError, "inputs to function body cannot alias outputs"
6534        ):
6535            _assert_tensors_nonaliasing(a, a)
6536
6537
6538if __name__ == "__main__":
6539    from torch._dynamo.test_case import run_tests
6540
6541    run_tests()
6542