xref: /aosp_15_r20/external/pytorch/test/dynamo/test_export.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2"""
3PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
4with test_export_persist_assert)
5"""
6import copy
7import functools
8import inspect
9import io
10import operator
11import unittest
12from enum import Enum
13from typing import Dict, List, Sequence
14from unittest.mock import patch
15
16import torch
17import torch._dynamo
18import torch._dynamo.test_case
19import torch._dynamo.testing
20from functorch.experimental.control_flow import cond
21from torch._dynamo import config
22from torch._dynamo.exc import UserError
23from torch._dynamo.testing import normalize_gm
24from torch._higher_order_ops.out_dtype import out_dtype
25from torch._subclasses import fake_tensor
26from torch.fx.experimental.proxy_tensor import make_fx
27from torch.fx.experimental.symbolic_shapes import (
28    ConstraintViolationError,
29    DimDynamic,
30    ShapeEnv,
31    StatelessSymbolicContext,
32)
33from torch.testing._internal import common_utils
34from torch.testing._internal.common_cuda import TEST_CUDA
35
36
37class ExportTests(torch._dynamo.test_case.TestCase):
38    # TODO(voz): Refactor to a shared test function.
39    # The tests in this file are a little redundant,
40    # They all take a func, run it with eager, then export it, then compare
41    def test_export(self):
42        def pre_attention_state_ops(input, mems, state):
43            lc_key = state[0]
44            lc_val = state[1]
45            bar = []
46            for i in range(0, 4):
47                bar2 = []
48                for j in range(0, 3):
49                    bar2.append(
50                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
51                    )
52                bar.append(bar2)
53
54            return bar
55
56        def func():
57            mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
58            state = [
59                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
60                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
61            ]
62            i = torch.tensor(
63                [
64                    [0.0313, -0.1487, -0.3846, -0.5321],
65                    [-1.7073, 1.3331, -0.0890, -1.4935],
66                    [-0.8314, -0.1862, -0.5935, 1.5232],
67                ]
68            )
69            return pre_attention_state_ops(i, mems, state)
70
71        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
72        real_result = opt_func()
73
74        torch._dynamo.reset()
75
76        exported = torch._dynamo.export(func)()
77        out_graph = exported[0]
78
79        dynamo_result = out_graph()
80        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
81
82    def test_no_tensor_computation_fail(self):
83        with self.assertRaisesRegex(
84            AssertionError,
85            "Failed to produce a graph",
86        ):
87            inp = [torch.randn(3)]
88            inp2 = 2
89            inps = [inp, inp2]
90
91            def func(x, y):
92                return x
93
94            exported = torch._dynamo.export(func, same_signature=False)(*inps)
95
96    def test_no_tensor_computation(self):
97        inp = [torch.randn(3)]
98        inp2 = 2
99        inps = [inp, inp2]
100
101        def func(x, y):
102            return x
103
104        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
105        real_result = opt_func(*inps)
106
107        torch._dynamo.reset()
108
109        exported = torch._dynamo.export(func)(*inps)
110        out_graph = exported[0]
111
112        dynamo_result = out_graph(*inps)
113
114        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
115        self.assertExpectedInline(
116            out_graph.code.strip(),
117            """\
118def forward(self, x, y):
119    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
120    x = arg0
121    return pytree.tree_unflatten([x], self._out_spec)""",
122        )
123
124    def test_no_tensor_computation_2(self):
125        inp = torch.randn(3)
126        inp2 = 2
127        inps = [inp, inp2]
128
129        def func(x, y):
130            return y
131
132        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
133        real_result = opt_func(*inps)
134
135        torch._dynamo.reset()
136
137        exported = torch._dynamo.export(func)(*inps)
138        out_graph = exported[0]
139
140        dynamo_result = out_graph(*inps)
141
142        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
143        self.assertExpectedInline(
144            out_graph.code.strip(),
145            """\
146def forward(self, x, y):
147    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
148    x = arg0
149    return pytree.tree_unflatten([2], self._out_spec)""",
150        )
151
152    def test_export_mismatched_out(self):
153        def func(x):
154            y = x + 1
155            return ([x, x], (y, y))
156
157        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
158        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
159
160        torch._dynamo.reset()
161
162        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
163        out_graph = exported[0]
164
165        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
166
167        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
168
169    def test_export_shape_control_flow_1(self):
170        def func(x):
171            if x.shape[0] > 10:
172                return x.cos()
173            return x.sin()
174
175        opt_func = torch._dynamo.optimize("eager")(func)
176        real_result = opt_func(torch.ones(6, 4))
177
178        torch._dynamo.reset()
179
180        exported = torch._dynamo.export(func)(torch.ones(6, 4))
181        out_graph, out_guards = exported
182
183        dynamo_result = out_graph(torch.ones(6, 4))
184
185        from torch._guards import GuardSource
186
187        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
188        hit = False
189        for guard in out_guards:
190            if guard.source == GuardSource.SHAPE_ENV:
191                hit = True
192                self.assertExpectedInline(
193                    guard.code_list,
194                    """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""",  # noqa: B950
195                )
196                break
197
198        self.assertTrue(hit)
199
200    def test_export_control_flow_with_getattr(self):
201        class Animal(Enum):
202            COW = "moo"
203
204        class MyModule(torch.nn.Module):
205            def __init__(self, a):
206                super().__init__()
207                self.a = a
208
209            def forward(self, x):
210                if self.a == Animal.COW.value:
211                    return x * x
212                else:
213                    raise ValueError("bad")
214
215        module = MyModule("moo")
216        input = (torch.ones(4, 3),)
217        resA = module(*input)
218        graph, _ = torch._dynamo.export(module)(*input)
219        resB = graph(*input)
220        self.assertTrue(torch._dynamo.utils.same(resA, resB))
221
222    def test_export_graph_bypass(self):
223        inp = [
224            torch.tensor([0.1, 0.1]),
225            torch.tensor([0.2, 0.2]),
226            torch.tensor([0.3, 0.3]),
227        ]
228
229        def func(x):
230            first = x[2]
231            second = x[2]
232            return first * second
233
234        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
235        real_result = opt_func(inp)
236
237        torch._dynamo.reset()
238
239        exported = torch._dynamo.export(func)(inp)
240        out_graph = exported[0]
241
242        dynamo_result = out_graph(inp)
243
244        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
245
246    def test_list_unpack(self):
247        inp = [
248            torch.tensor([0.1, 0.1]),
249            torch.tensor([0.2, 0.2]),
250            torch.tensor([0.3, 0.3]),
251        ]
252
253        def func(x):
254            first = x[2]
255            second = x[2]
256            return x[0], first * second, x[1], x[2]
257
258        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
259        real_result = opt_func(inp)
260
261        torch._dynamo.reset()
262
263        exported = torch._dynamo.export(func)(inp)
264        out_graph = exported[0]
265
266        dynamo_result = out_graph(inp)
267
268        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
269
270    def test_export_with_shallow_list_copy_wo_side_effects(self):
271        def f(x):
272            y = x.copy()
273            return y[0] + y[1]
274
275        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
276        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
277            inp
278        ).graph_module
279        self.assertTrue(torch._dynamo.utils.same(gm(inp), f(inp)))
280
281    def test_export_with_shallow_list_copy_with_side_effects(self):
282        def f(x):
283            y = x.copy()
284            x[0] = x[1]
285            y.append(torch.tensor([[100]]))
286            return x[0] + x[1], y[0] + y[1], y[2]
287
288        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
289        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
290            inp
291        ).graph_module
292        res = gm(inp)
293        ref = f(inp)
294        self.assertTrue(torch._dynamo.utils.same(res, ref))
295        self.assertEqual(res[0], res[1])
296
297    def test_export_mismatched_out_2(self):
298        def func(x):
299            y = x + 1
300            return ([x, x], (y, y))
301
302        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
303        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
304
305        torch._dynamo.reset()
306
307        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
308        out_graph = exported[0]
309
310        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
311
312        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
313
314    def test_export_graph_with_list(self):
315        inp = [
316            torch.tensor([0.1, 0.1]),
317            torch.tensor([0.2, 0.2]),
318            torch.tensor([0.3, 0.3]),
319            torch.tensor([0.4, 0.4]),
320        ]
321
322        def func(x):
323            first = x[2]
324            second = x[2]
325            return first * second, x
326
327        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
328        real_result = opt_func(inp)
329
330        torch._dynamo.reset()
331
332        exported = torch._dynamo.export(func)(inp)
333        out_graph = exported[0]
334
335        dynamo_result = out_graph(inp)
336
337        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
338
339    def test_export_graph_with_complex_reorder(self):
340        inp = [
341            torch.tensor([0.1, 0.1]),
342            torch.tensor([0.2, 0.2]),
343            torch.tensor([0.3, 0.3]),
344            torch.tensor([0.4, 0.4]),
345        ]
346
347        def func(x):
348            first = x[0]
349            second = x[1]
350            third = x[2]
351            return third, first, second, first * second, first * third
352
353        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
354        real_result = opt_func(inp)
355
356        torch._dynamo.reset()
357
358        exported = torch._dynamo.export(func)(inp)
359        out_graph = exported[0]
360
361        dynamo_result = out_graph(inp)
362
363        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
364
365    def test_dupes(self):
366        inp = torch.tensor([0.1, 0.1])
367
368        def func(x):
369            y = x + 1
370            return y, y
371
372        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
373        real_result = opt_func(inp)
374
375        torch._dynamo.reset()
376
377        exported = torch._dynamo.export(func)(inp)
378        out_graph = exported[0]
379
380        dynamo_result = out_graph(inp)
381
382        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
383
384    def test_dupes_2(self):
385        inp = torch.tensor([0.1, 0.1])
386
387        def func(x):
388            y = x + 1
389            return y, y
390
391        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
392        real_result = opt_func(inp)
393
394        torch._dynamo.reset()
395
396        exported = torch._dynamo.export(func)(inp)
397        out_graph = exported[0]
398
399        dynamo_result = out_graph(inp)
400
401        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
402
403    def test_dupes_and_bypass(self):
404        inp = torch.tensor([0.1, 0.1])
405        inp2 = torch.tensor([0.4, 0.4])
406        inps = [inp, inp2]
407
408        def func(x, z):
409            y = x + 1
410            return y, y, z
411
412        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
413        real_result = opt_func(*inps)
414
415        torch._dynamo.reset()
416
417        exported = torch._dynamo.export(func)(*inps)
418        out_graph = exported[0]
419
420        dynamo_result = out_graph(*inps)
421
422        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
423
424    def test_dupes_and_bypass_with_non_tensor_arg(self):
425        inp = torch.tensor([0.1, 0.1])
426        inp2 = torch.tensor([0.1, 0.1])
427        inp3 = 4
428        inps = [inp, inp2, inp3]
429
430        def func(x, z, k):
431            y = x + k
432            return y, y, z
433
434        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
435        real_result = opt_func(*inps)
436
437        torch._dynamo.reset()
438
439        exported = torch._dynamo.export(func)(*inps)
440        out_graph = exported[0]
441
442        dynamo_result = out_graph(*inps)
443
444        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
445
446    def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
447        inp = torch.tensor([0.1, 0.1])
448        inp2 = torch.tensor([0.1, 0.1])
449        inp3 = 4
450        inps = [inp, inp2, inp3]
451
452        def func(x, z, k):
453            y = x + k
454            return z, y, y
455
456        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
457        real_result = opt_func(*inps)
458
459        torch._dynamo.reset()
460
461        exported = torch._dynamo.export(func)(*inps)
462        out_graph = exported[0]
463
464        dynamo_result = out_graph(*inps)
465
466        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
467
468    @config.patch(capture_scalar_outputs=True)
469    def test_dupes_and_bypass_with_non_tensor_output(self):
470        inp = torch.tensor([0.1, 0.1])
471        inp2 = torch.tensor([0.1, 0.1])
472        inp3 = 4
473        inps = [inp, inp2, inp3]
474
475        def func(x, z, k):
476            y = x + k
477            return y[0].item(), y, z
478
479        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
480        real_result = opt_func(*inps)
481
482        torch._dynamo.reset()
483
484        exported = torch._dynamo.export(func)(*inps)
485        out_graph = exported[0]
486
487        dynamo_result = out_graph(*inps)
488
489        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
490
491    def test_zeroes_in_and_out_different_shape_on_test(self):
492        inp = torch.zeros(10)
493        inp2 = torch.zeros(10)
494        inp3 = torch.zeros(10)
495        inps = [inp, inp2, inp3]
496
497        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
498
499        def func(a, b, c):
500            return [[a], [b, c], [a + b], [[c + c]]]
501
502        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
503        real_result = opt_func(*inps_rand)
504
505        torch._dynamo.reset()
506
507        exported = torch._dynamo.export(func)(*inps)
508        out_graph = exported[0]
509
510        dynamo_result = out_graph(*inps_rand)
511
512        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
513
514    @config.patch(capture_scalar_outputs=True)
515    def test_zeroes_in_new_shape_scalar_out(self):
516        inp = torch.zeros(10)
517        inp2 = torch.zeros(10)
518        inp3 = torch.zeros(10)
519        inps = [inp, inp2, inp3]
520
521        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
522
523        def func(a, b, c):
524            return a[0].item() + b[0].item() + c[0].item()
525
526        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
527        real_result = opt_func(*inps_rand)
528
529        torch._dynamo.reset()
530
531        exported = torch._dynamo.export(func)(*inps)
532        out_graph = exported[0]
533
534        dynamo_result = out_graph(*inps_rand)
535
536        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
537
538    @config.patch(capture_scalar_outputs=True)
539    def test_zeroes_in_new_shape_scalar_out_permute(self):
540        inp = torch.zeros(10)
541        inp2 = torch.zeros(10)
542        inp3 = torch.zeros(10)
543        inps = [inp, inp2, inp3]
544
545        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
546
547        def func(a, b, c):
548            return b[0].item() + c[0].item() + a[0].item() + a[0].item()
549
550        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
551        real_result = opt_func(*inps_rand)
552
553        torch._dynamo.reset()
554
555        exported = torch._dynamo.export(func)(*inps)
556        out_graph = exported[0]
557
558        dynamo_result = out_graph(*inps_rand)
559
560        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
561
562    @config.patch(capture_scalar_outputs=True)
563    def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
564        inp = torch.zeros(10)
565        inp2 = torch.zeros(10)
566        inp3 = torch.zeros(10)
567        inps = [inp, inp2, inp3]
568
569        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
570
571        def func(a, b, c):
572            return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a
573
574        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
575        real_result = opt_func(*inps_rand)
576
577        torch._dynamo.reset()
578
579        exported = torch._dynamo.export(func)(*inps)
580        out_graph = exported[0]
581
582        dynamo_result = out_graph(*inps_rand)
583
584        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
585
586    def test_func_return(self):
587        inp = torch.zeros(10)
588        inp2 = torch.zeros(10)
589        inp3 = torch.zeros(10)
590        inps = [inp, inp2, inp3]
591
592        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
593
594        def func(a, b, c):
595            x = a + b + c
596
597            def func2(y):
598                return x * y
599
600            return func2(x)
601
602        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
603        real_result = opt_func(*inps_rand)
604
605        torch._dynamo.reset()
606
607        exported = torch._dynamo.export(func)(*inps)
608        out_graph = exported[0]
609
610        dynamo_result = out_graph(*inps_rand)
611
612        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
613
614    def test_dict_return(self):
615        inp = torch.zeros(10)
616        inp2 = torch.zeros(10)
617        inp3 = torch.zeros(10)
618        inps = [inp, inp2, inp3]
619
620        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
621
622        def func(a, b, c):
623            x = a + b + c
624            return {"a": x}
625
626        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
627        real_result = opt_func(*inps_rand)
628
629        torch._dynamo.reset()
630
631        exported = torch._dynamo.export(func)(*inps)
632        out_graph = exported[0]
633
634        dynamo_result = out_graph(*inps_rand)
635
636        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
637
638    def test_export_with_aten_graph(self):
639        def pre_attention_state_ops(input, mems, state):
640            lc_key = state[0]
641            lc_val = state[1]
642            bar = []
643            for i in range(0, 4):
644                bar2 = []
645                for j in range(0, 3):
646                    bar2.append(
647                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
648                    )
649                bar.append(bar2)
650
651            return bar
652
653        def func():
654            mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
655            state = [
656                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
657                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
658            ]
659            i = torch.tensor(
660                [
661                    [0.0313, -0.1487, -0.3846, -0.5321],
662                    [-1.7073, 1.3331, -0.0890, -1.4935],
663                    [-0.8314, -0.1862, -0.5935, 1.5232],
664                ]
665            )
666            return pre_attention_state_ops(i, mems, state)
667
668        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
669        real_result = opt_func()
670
671        torch._dynamo.reset()
672
673        exported = torch._dynamo.export(func, aten_graph=True)()
674        out_graph = exported[0]
675
676        dynamo_result = out_graph()
677        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
678
679    def test_export_no_tensor_computation_with_aten_graph(self):
680        inp = [torch.randn(3)]
681        inp2 = 2
682        inps = [inp, inp2]
683
684        def func(x, y):
685            return x
686
687        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
688        real_result = opt_func(*inps)
689
690        torch._dynamo.reset()
691
692        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
693        out_graph = exported[0]
694
695        dynamo_result = out_graph(*inps)
696
697        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
698        self.assertExpectedInline(
699            out_graph.code.strip(),
700            """\
701def forward(self, x, y):
702    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
703    arg0_1 = arg0
704    return pytree.tree_unflatten([arg0_1], self._out_spec)""",
705        )
706
707    def test_no_tensor_computation_2_with_aten_graph(self):
708        inp = torch.randn(3)
709        inp2 = 2
710        inps = [inp, inp2]
711
712        def func(x, y):
713            return y
714
715        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
716        real_result = opt_func(*inps)
717
718        torch._dynamo.reset()
719
720        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
721        out_graph = exported[0]
722
723        dynamo_result = out_graph(*inps)
724
725        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
726        self.assertExpectedInline(
727            out_graph.code.strip(),
728            """\
729def forward(self, x, y):
730    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
731    arg0_1 = arg0
732    return pytree.tree_unflatten([2], self._out_spec)""",
733        )
734
735    def test_export_mismatched_out_with_aten_graph(self):
736        def func(x):
737            y = x + 1
738            return ([x, x], (y, y))
739
740        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
741        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
742
743        torch._dynamo.reset()
744
745        exported = torch._dynamo.export(func, aten_graph=True)(
746            torch.tensor([[[1.3737, 0.1]]])
747        )
748        out_graph = exported[0]
749
750        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
751
752        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
753
754    def test_export_graph_bypass_with_aten_graph(self):
755        inp = [
756            torch.tensor([0.1, 0.1]),
757            torch.tensor([0.2, 0.2]),
758            torch.tensor([0.3, 0.3]),
759        ]
760
761        def func(x):
762            first = x[2]
763            second = x[2]
764            return first * second
765
766        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
767        real_result = opt_func(inp)
768
769        torch._dynamo.reset()
770
771        exported = torch._dynamo.export(func, aten_graph=True)(inp)
772        out_graph = exported[0]
773
774        dynamo_result = out_graph(inp)
775
776        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
777
778    def test_list_unpack_with_aten_graph(self):
779        inp = [
780            torch.tensor([0.1, 0.1]),
781            torch.tensor([0.2, 0.2]),
782            torch.tensor([0.3, 0.3]),
783        ]
784
785        def func(x):
786            first = x[2]
787            second = x[2]
788            return x[0], first * second, x[1], x[2]
789
790        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
791        real_result = opt_func(inp)
792
793        torch._dynamo.reset()
794
795        exported = torch._dynamo.export(func, aten_graph=True)(inp)
796        out_graph = exported[0]
797
798        dynamo_result = out_graph(inp)
799
800        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
801
802    def test_export_mismatched_out_2_with_aten_graph(self):
803        def func(x):
804            y = x + 1
805            return ([x, x], (y, y))
806
807        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
808        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
809
810        torch._dynamo.reset()
811
812        exported = torch._dynamo.export(func, aten_graph=True)(
813            torch.tensor([[[1.3737, 0.1]]])
814        )
815        out_graph = exported[0]
816
817        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))
818
819        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
820
821    def test_export_graph_with_list_with_aten_graph(self):
822        inp = [
823            torch.tensor([0.1, 0.1]),
824            torch.tensor([0.2, 0.2]),
825            torch.tensor([0.3, 0.3]),
826            torch.tensor([0.4, 0.4]),
827        ]
828
829        def func(x):
830            first = x[2]
831            second = x[2]
832            return first * second, x
833
834        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
835        real_result = opt_func(inp)
836
837        torch._dynamo.reset()
838
839        exported = torch._dynamo.export(func, aten_graph=True)(inp)
840        out_graph = exported[0]
841
842        dynamo_result = out_graph(inp)
843
844        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
845
846    def test_export_graph_with_complex_reorder_with_aten_graph(self):
847        inp = [
848            torch.tensor([0.1, 0.1]),
849            torch.tensor([0.2, 0.2]),
850            torch.tensor([0.3, 0.3]),
851            torch.tensor([0.4, 0.4]),
852        ]
853
854        def func(x):
855            first = x[0]
856            second = x[1]
857            third = x[2]
858            return third, first, second, first * second, first * third
859
860        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
861        real_result = opt_func(inp)
862
863        torch._dynamo.reset()
864
865        exported = torch._dynamo.export(func, aten_graph=True)(inp)
866        out_graph = exported[0]
867
868        dynamo_result = out_graph(inp)
869
870        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
871
872    def test_dupes_with_aten_graph(self):
873        inp = torch.tensor([0.1, 0.1])
874
875        def func(x):
876            y = x + 1
877            return y, y
878
879        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
880        real_result = opt_func(inp)
881
882        torch._dynamo.reset()
883
884        exported = torch._dynamo.export(func, aten_graph=True)(inp)
885        out_graph = exported[0]
886
887        dynamo_result = out_graph(inp)
888
889        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
890
891    def test_dupes_2_with_aten_graph(self):
892        inp = torch.tensor([0.1, 0.1])
893
894        def func(x):
895            y = x + 1
896            return y, y
897
898        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
899        real_result = opt_func(inp)
900
901        torch._dynamo.reset()
902
903        exported = torch._dynamo.export(func, aten_graph=True)(inp)
904        out_graph = exported[0]
905
906        dynamo_result = out_graph(inp)
907
908        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
909
910    def test_dupes_and_bypass_with_aten_graph(self):
911        inp = torch.tensor([0.1, 0.1])
912        inp2 = torch.tensor([0.4, 0.4])
913        inps = [inp, inp2]
914
915        def func(x, z):
916            y = x + 1
917            return y, y, z
918
919        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
920        real_result = opt_func(*inps)
921
922        torch._dynamo.reset()
923
924        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
925        out_graph = exported[0]
926
927        dynamo_result = out_graph(*inps)
928
929        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
930
931    def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
932        inp = torch.tensor([0.1, 0.1])
933        inp2 = torch.tensor([0.1, 0.1])
934        inp3 = 4
935        inps = [inp, inp2, inp3]
936
937        def func(x, z, k):
938            y = x + k
939            return y, y, z
940
941        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
942        real_result = opt_func(*inps)
943
944        torch._dynamo.reset()
945
946        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
947        out_graph = exported[0]
948
949        dynamo_result = out_graph(*inps)
950
951        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
952
953    def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
954        inp = torch.tensor([0.1, 0.1])
955        inp2 = torch.tensor([0.1, 0.1])
956        inp3 = 4
957        inps = [inp, inp2, inp3]
958
959        def func(x, z, k):
960            y = x + k
961            return z, y, y
962
963        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
964        real_result = opt_func(*inps)
965
966        torch._dynamo.reset()
967
968        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
969        out_graph = exported[0]
970
971        dynamo_result = out_graph(*inps)
972
973        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
974
975    @config.patch(capture_scalar_outputs=True)
976    def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
977        inp = torch.tensor([0.1, 0.1])
978        inp2 = torch.tensor([0.1, 0.1])
979        inp3 = 4
980        inps = [inp, inp2, inp3]
981
982        def func(x, z, k):
983            y = x + k
984            return y[0].item(), y, z
985
986        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
987        real_result = opt_func(*inps)
988
989        torch._dynamo.reset()
990
991        exported = torch._dynamo.export(func)(*inps)
992        out_graph = exported[0]
993
994        dynamo_result = out_graph(*inps)
995
996        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
997
998    def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
999        inp = torch.zeros(10)
1000        inp2 = torch.zeros(10)
1001        inp3 = torch.zeros(10)
1002        inps = [inp, inp2, inp3]
1003
1004        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
1005
1006        def func(a, b, c):
1007            return [[a], [b, c], [a + b], [[c + c]]]
1008
1009        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
1010        real_result = opt_func(*inps_rand)
1011
1012        torch._dynamo.reset()
1013
1014        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
1015        out_graph = exported[0]
1016
1017        dynamo_result = out_graph(*inps_rand)
1018
1019        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
1020
1021    def test_func_return_with_aten_graph(self):
1022        inp = torch.zeros(10)
1023        inp2 = torch.zeros(10)
1024        inp3 = torch.zeros(10)
1025        inps = [inp, inp2, inp3]
1026
1027        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
1028
1029        def func(a, b, c):
1030            x = a + b + c
1031
1032            def func2(y):
1033                return x * y
1034
1035            return func2(x)
1036
1037        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
1038        real_result = opt_func(*inps_rand)
1039
1040        torch._dynamo.reset()
1041
1042        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
1043        out_graph = exported[0]
1044
1045        dynamo_result = out_graph(*inps_rand)
1046
1047        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
1048
1049    def test_dict_return_with_aten_graph(self):
1050        inp = torch.zeros(10)
1051        inp2 = torch.zeros(10)
1052        inp3 = torch.zeros(10)
1053        inps = [inp, inp2, inp3]
1054
1055        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]
1056
1057        def func(a, b, c):
1058            x = a + b + c
1059            return {"a": x}
1060
1061        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
1062        real_result = opt_func(*inps_rand)
1063
1064        torch._dynamo.reset()
1065
1066        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
1067        out_graph = exported[0]
1068
1069        dynamo_result = out_graph(*inps_rand)
1070
1071        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
1072
1073    def test_export_with_stack_trace(self):
1074        inp = torch.randn(4, 4)
1075
1076        class MyBlock(torch.nn.Module):
1077            def forward(self, x):
1078                x = torch.nn.functional.linear(x, torch.randn(4, 4))
1079                return torch.cos(x).relu() + 1
1080
1081        class MyModule(torch.nn.Module):
1082            def __init__(self) -> None:
1083                super().__init__()
1084                self.block = MyBlock()
1085
1086            def forward(self, x):
1087                out = self.block(x)
1088                return out
1089
1090        exported = torch._dynamo.export(MyModule(), aten_graph=False)(inp)
1091        out_graph = exported[0]
1092
1093        for node in out_graph.graph.nodes:
1094            if node.op not in {"placeholder", "output"}:
1095                self.assertTrue(node.stack_trace is not None)
1096                self.assertTrue(node.meta["nn_module_stack"] is not None)
1097                self.assertTrue(node.meta["source_fn_stack"] is not None)
1098
1099        torch._dynamo.reset()
1100
1101        exported = torch._dynamo.export(MyModule(), aten_graph=True)(inp)
1102        out_graph = exported[0]
1103        for node in out_graph.graph.nodes:
1104            if node.op == "call_function":
1105                self.assertTrue(node.stack_trace is not None)
1106                self.assertTrue(node.meta["nn_module_stack"] is not None)
1107                self.assertTrue(node.meta["source_fn_stack"] is not None)
1108                self.assertTrue(node.meta["val"] is not None)
1109                self.assertTrue(node.meta["original_aten"] is not None)
1110
1111    def test_export_preserves_nn_module_stack_for_get_attr(self):
1112        inp = torch.randn(4, 4)
1113
1114        class MyBlock(torch.nn.Module):
1115            def __init__(self) -> None:
1116                super().__init__()
1117                self.weight = torch.nn.Parameter(torch.ones(1, 1))
1118                self.buffer = torch.nn.Buffer(torch.ones(1, 1))
1119
1120            def forward(self, x):
1121                x = torch.nn.functional.linear(x, torch.randn(4, 4))
1122                return torch.cos(x).relu() + self.weight + self.buffer
1123
1124        class MyModule(torch.nn.Module):
1125            def __init__(self) -> None:
1126                super().__init__()
1127                self.block = MyBlock()
1128
1129            def forward(self, x):
1130                out = self.block(x)
1131                return out
1132
1133        m = MyModule()
1134        exported = torch._dynamo.export(m, aten_graph=False)(inp)
1135        out_graph = exported[0]
1136
1137        attr_access_count = 0
1138        for node in out_graph.graph.nodes:
1139            if node.op == "get_attr":
1140                attr_access_count += 1
1141                self.assertTrue(node.meta["nn_module_stack"] is not None)
1142        self.assertEqual(attr_access_count, 2)
1143
1144        torch._dynamo.reset()
1145
1146        exported = torch._dynamo.export(m, aten_graph=True)(inp)
1147        out_graph = exported[0]
1148
1149        attr_access_count = 0
1150        for node in out_graph.graph.nodes:
1151            if node.op == "get_attr":
1152                attr_access_count += 1
1153                self.assertTrue(node.meta["nn_module_stack"] is not None)
1154        self.assertEqual(attr_access_count, 2)
1155
1156    def test_export_compare_optimize_with_make_fx(self):
1157        inp = torch.tensor([0.1, 0.1])
1158        linear = torch.nn.Linear(2, 2)
1159
1160        def func(x):
1161            x = x + 1
1162            y = x.t()
1163            y = y.relu()
1164            y = linear(y)
1165            return y
1166
1167        exported = torch._dynamo.export(func, aten_graph=True)(inp)
1168        out_graph = exported[0]
1169        export_result = out_graph(inp)
1170
1171        torch._dynamo.reset()
1172
1173        def compiler(gm, sample_inputs):
1174            def fw(*args):
1175                aten_gm = make_fx(gm)(*args)
1176                return aten_gm(*args)
1177
1178            return fw
1179
1180        opt_func = torch._dynamo.optimize(compiler, nopython=True, dynamic=True)(func)
1181        make_fx_result_through_backend = opt_func(inp)
1182
1183        fx_g = make_fx(func)(inp)
1184        make_fx_result_through_direct = fx_g(inp)
1185
1186        self.assertTrue(
1187            torch._dynamo.utils.same(make_fx_result_through_backend, export_result)
1188        )
1189        self.assertTrue(
1190            torch._dynamo.utils.same(make_fx_result_through_direct, export_result)
1191        )
1192
1193    def test_export_with_constant_method_on_module(self):
1194        class MyModule(torch.nn.Module):
1195            def __init__(self) -> None:
1196                super().__init__()
1197                self.param = torch.nn.Parameter(torch.rand(4, 2))
1198                self.linear = torch.nn.Linear(2, 2)
1199
1200            @torch._dynamo.assume_constant_result
1201            def helper_fn(self, x):
1202                return torch.nonzero(x)
1203
1204            def forward(self, x):
1205                y = torch.sin(x)
1206                x = self.linear(x)
1207                y = self.helper_fn(x)
1208                return y
1209
1210        module = MyModule()
1211        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
1212        module = MyModule()
1213        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
1214        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
1215        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1216        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
1217        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1218
1219    def test_export_with_constant_method_on_module_invoke_twice(self):
1220        class MyModule(torch.nn.Module):
1221            def __init__(self) -> None:
1222                super().__init__()
1223                self.param = torch.nn.Parameter(torch.rand(4, 2))
1224                self.linear = torch.nn.Linear(2, 2)
1225
1226            @torch._dynamo.assume_constant_result
1227            def helper_fn(self, x):
1228                return torch.nonzero(x)
1229
1230            def forward(self, x):
1231                y = torch.sin(x)
1232                x = self.linear(x)
1233                y = self.helper_fn(x) + self.helper_fn(x)
1234                return y
1235
1236        module = MyModule()
1237        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
1238        module = MyModule()
1239        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
1240        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
1241        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1242        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
1243        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1244
1245    def test_export_with_constant_free_function(self):
1246        @torch._dynamo.assume_constant_result
1247        def helper_fn(x):
1248            return torch.nonzero(x)
1249
1250        class MyModule(torch.nn.Module):
1251            def __init__(self) -> None:
1252                super().__init__()
1253                self.param = torch.nn.Parameter(torch.rand(4, 2))
1254                self.linear = torch.nn.Linear(2, 2)
1255
1256            @torch._dynamo.assume_constant_result
1257            def helper_fn(self, x):
1258                return torch.nonzero(x)
1259
1260            def forward(self, x):
1261                y = torch.sin(x)
1262                x = self.linear(x)
1263                y = helper_fn(x) + self.helper_fn(x)
1264                return y
1265
1266        module = MyModule()
1267        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
1268        module = MyModule()
1269        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
1270        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
1271        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1272        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
1273        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1274
1275    def test_export_with_constant_free_function_and_class_method(self):
1276        @torch._dynamo.assume_constant_result
1277        def helper_fn(x):
1278            return torch.nonzero(x)
1279
1280        class MyModule(torch.nn.Module):
1281            def __init__(self) -> None:
1282                super().__init__()
1283                self.param = torch.nn.Parameter(torch.rand(4, 2))
1284                self.linear = torch.nn.Linear(2, 2)
1285
1286            def forward(self, x):
1287                y = torch.sin(x)
1288                x = self.linear(x)
1289                y = helper_fn(x)
1290                return y
1291
1292        module = MyModule()
1293        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
1294        module = MyModule()
1295        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
1296        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
1297        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1298        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
1299        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1300
1301    def test_export_with_constant_free_function_and_class_method_multiarg(self):
1302        @torch._dynamo.assume_constant_result
1303        def helper_fn(x):
1304            return torch.nonzero(x)
1305
1306        class MyModule(torch.nn.Module):
1307            def __init__(self) -> None:
1308                super().__init__()
1309                self.param = torch.nn.Parameter(torch.rand(4, 2))
1310                self.linear = torch.nn.Linear(2, 2)
1311
1312            def forward(self, x, z):
1313                y = torch.sin(x)
1314                x = self.linear(x)
1315                y = helper_fn(x) + helper_fn(z)
1316                return y
1317
1318        module = MyModule()
1319        real_result = module(
1320            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
1321        )
1322        module = MyModule()
1323        graph, _ = torch._dynamo.export(module)(
1324            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
1325        )
1326        result = graph(
1327            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])
1328        )
1329        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1330        result = graph(
1331            torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]])
1332        )
1333        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1334
1335    def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
1336        @torch._dynamo.assume_constant_result
1337        def helper_fn(x):
1338            return torch.nonzero(x)
1339
1340        class MyModule(torch.nn.Module):
1341            def forward(self, x, z):
1342                y = helper_fn(x) + helper_fn(z)
1343                return y
1344
1345        module = MyModule()
1346        real_result = module(
1347            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
1348        )
1349        module = MyModule()
1350        graph, _ = torch._dynamo.export(module)(
1351            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
1352        )
1353        result = graph(
1354            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])
1355        )
1356        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1357        result = graph(
1358            torch.tensor([[1, 0], [0.25, 0.25]]),
1359            torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
1360        )
1361        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1362
1363    def test_export_with_constant_tuple_nonzero(self):
1364        class MyModule(torch.nn.Module):
1365            @torch._dynamo.assume_constant_result
1366            def helper_fn(self, x):
1367                return (torch.nonzero(x), torch.nonzero(x))
1368
1369            def forward(self, x):
1370                y = torch.tensor([0.5])
1371                elements = self.helper_fn(x)
1372                all_y = []
1373                for element in elements:
1374                    for item in element:
1375                        all_y.append(y * item)
1376                return all_y
1377
1378        module = MyModule()
1379        real_result = module(torch.tensor([1.0, 1.0]))
1380        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))
1381
1382        # Tensor input can be almost anything here, and the result will capture what we
1383        # made constant at compile time.
1384        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
1385        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1386
1387    def test_export_with_constant_list_nonzero(self):
1388        class MyModule(torch.nn.Module):
1389            @torch._dynamo.assume_constant_result
1390            def helper_fn(self, x):
1391                return [torch.nonzero(x), torch.nonzero(x)]
1392
1393            def forward(self, x):
1394                y = torch.tensor([0.5])
1395                elements = self.helper_fn(x)
1396                all_y = []
1397                for element in elements:
1398                    for item in element:
1399                        all_y.append(y * item)
1400                return all_y
1401
1402        module = MyModule()
1403        real_result = module(torch.tensor([1.0, 1.0]))
1404        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))
1405
1406        # Tensor input can be almost anything here, and the result will capture what we
1407        # made constant at compile time.
1408        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
1409        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1410
1411    def test_export_with_constant_list_nonzero_free_function(self):
1412        @torch._dynamo.assume_constant_result
1413        def helper_fn(x):
1414            return [torch.nonzero(x), torch.nonzero(x)]
1415
1416        class MyModule(torch.nn.Module):
1417            def forward(self, x):
1418                y = torch.tensor([0.5])
1419                elements = helper_fn(x)
1420                all_y = []
1421                for element in elements:
1422                    for item in element:
1423                        all_y.append(y * item)
1424                return all_y
1425
1426        module = MyModule()
1427        real_result = module(torch.tensor([1.0, 1.0]))
1428        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))
1429
1430        # Tensor input can be almost anything here, and the result will capture what we
1431        # made constant at compile time.
1432        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
1433        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1434
1435    def test_export_with_constant_dict_values(self):
1436        class MyModule(torch.nn.Module):
1437            @torch._dynamo.assume_constant_result
1438            def helper_fn(self, x):
1439                return {"x": x, "x^2": x * x}
1440
1441            def forward(self, x):
1442                y = torch.tensor([0.5])
1443                elements = self.helper_fn(x)
1444                y = y * elements["x"]
1445                y = y * elements["x^2"]
1446                return y
1447
1448        module = MyModule()
1449        real_result = module(torch.tensor([2.0, 2.0]))
1450        graph, guards = torch._dynamo.export(module)(torch.tensor([2.0, 2.0]))
1451
1452        # Tensor input can be almost anything here, and the result will capture what we
1453        # made constant at compile time.
1454        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
1455        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1456
1457    def test_export_with_constant_none_control_flow(self):
1458        class MyModule(torch.nn.Module):
1459            @torch._dynamo.assume_constant_result
1460            def helper_fn(self, x):
1461                if x.item() < 0:
1462                    return None
1463                else:
1464                    return x
1465
1466            def forward(self, x):
1467                y = torch.tensor([0.5])
1468                x = self.helper_fn(x)
1469                if x is None:
1470                    return y
1471                return y * x
1472
1473        module = MyModule()
1474        real_result = module(torch.tensor([-1]))
1475
1476        # X is negative, so .item() < 0, which means we return y
1477        self.assertEqual(real_result, torch.tensor([0.5]))
1478
1479        graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
1480        result = graph(torch.tensor([2]))
1481        # X is positive, but we compiled helper_fn to return None, so it will still return y
1482        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1483
1484    def test_export_with_constant_not_none_control_flow(self):
1485        class MyModule(torch.nn.Module):
1486            @torch._dynamo.assume_constant_result
1487            def helper_fn(self, x):
1488                if x.item() < 0:
1489                    return None
1490                else:
1491                    return x
1492
1493            def forward(self, x):
1494                y = torch.tensor([0.5])
1495                x = self.helper_fn(x)
1496                if x is None:
1497                    return y
1498                return y * x
1499
1500        module = MyModule()
1501        real_result = module(torch.tensor([2]))
1502
1503        # X is positive, so .item() > 0, which means we return y * x
1504        self.assertEqual(real_result, torch.tensor([1.0]))
1505
1506        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
1507        result = graph(torch.tensor([-0.5]))
1508        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
1509        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1510
1511    def test_export_with_constant_none_control_flow_free_func(self):
1512        @torch._dynamo.assume_constant_result
1513        def helper_fn(x):
1514            if x.item() < 0:
1515                return None
1516            else:
1517                return x
1518
1519        class MyModule(torch.nn.Module):
1520            def forward(self, x):
1521                y = torch.tensor([0.5])
1522                x = helper_fn(x)
1523                if x is None:
1524                    return y
1525                return y * x
1526
1527        module = MyModule()
1528        real_result = module(torch.tensor([-1]))
1529
1530        # X is negative, so .item() < 0, which means we return y
1531        self.assertEqual(real_result, torch.tensor([0.5]))
1532
1533        graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
1534        result = graph(torch.tensor([2]))
1535        # X is positive, but we compiled helper_fn to return None, so it will still return y
1536        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1537
1538    def test_export_with_constant_not_none_control_flow_pos(self):
1539        class MyModule(torch.nn.Module):
1540            @torch._dynamo.assume_constant_result
1541            def helper_fn(self, x):
1542                if x.item() < 0:
1543                    return None
1544                else:
1545                    return x
1546
1547            def forward(self, x):
1548                y = torch.tensor([0.5])
1549                x = self.helper_fn(x)
1550                if x is None:
1551                    return y
1552                return y * x
1553
1554        module = MyModule()
1555        real_result = module(torch.tensor([2]))
1556
1557        # X is positive, so .item() > 0, which means we return y * x
1558        self.assertEqual(real_result, torch.tensor([1.0]))
1559
1560        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
1561        result = graph(torch.tensor([-0.5]))
1562        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
1563        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1564
1565    def test_export_with_constant_not_none_control_flow_free_func(self):
1566        @torch._dynamo.assume_constant_result
1567        def helper_fn(x):
1568            if x.item() < 0:
1569                return None
1570            else:
1571                return x
1572
1573        class MyModule(torch.nn.Module):
1574            def forward(self, x):
1575                y = torch.tensor([0.5])
1576                x = helper_fn(x)
1577                if x is None:
1578                    return y
1579                return y * x
1580
1581        module = MyModule()
1582        real_result = module(torch.tensor([2]))
1583
1584        # X is positive, so .item() > 0, which means we return y * x
1585        self.assertEqual(real_result, torch.tensor([1.0]))
1586
1587        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
1588        result = graph(torch.tensor([-0.5]))
1589        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
1590        self.assertTrue(torch._dynamo.utils.same(result, real_result))
1591
1592    def test_export_with_constant_not_return_const(self):
1593        class MyModule(torch.nn.Module):
1594            @torch._dynamo.assume_constant_result
1595            def helper_fn(self, x):
1596                return self.val
1597
1598            def forward(self, x):
1599                y = torch.tensor([0.5])
1600                x = self.helper_fn(x)
1601                if x == "A":
1602                    return y
1603                return -1
1604
1605        module = MyModule()
1606        module.val = "A"
1607        resA = module(torch.tensor([2]))
1608        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
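        # helper_fn's return value ("A") was constant-folded into the graph at
        # export time, so mutating val below should not change the exported result.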
1609        module.val = "B"
1610        resB = graph(torch.tensor([2]))
1611        self.assertTrue(torch._dynamo.utils.same(resA, resB))
1612
1613    def test_export_with_builtin_op_on_assume_constant(self):
1614        @torch._dynamo.assume_constant_result
1615        def get_y(y) -> torch.Tensor:
1616            return y
1617
1618        class Bob(torch.nn.Module):
1619            def __init__(self, p, val) -> None:
1620                super().__init__()
1621                self.p = p
1622                self.y = torch.nn.Parameter(torch.tensor(val))
1623
1624            def forward(self, x: torch.Tensor) -> torch.Tensor:
1625                # This only looks dynamic but it's actually a constant value
1626                if get_y(self.y) < self.p:
1627                    return torch.cat([x, x])
1628                else:
1629                    return x
1630
1631        model = Bob(0.5, 0.3)
1632        inp = torch.ones(3, 4)
1633        graph, guards = torch._dynamo.export(model)(inp)
1634        self.assertEqual(model(inp), graph(inp))
1635
1636    def test_export_with_constant_in_unspecialized_nn_module(self):
1637        class Module(torch.nn.Module):
1638            def __init__(self, y):
1639                super().__init__()
1640                self.y = y
1641
1642            @torch._dynamo.assume_constant_result
1643            def check(self):
1644                return self.y[0].item() == 1
1645
1646            def forward(self, x):
1647                # This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo
1648                self.device = x.device
1649
1650                if self.check():
1651                    return x + 1
1652                else:
1653                    return x + 2
1654
1655        model = Module(torch.tensor([1]))
1656        inp = torch.ones(3, 4)
1657        graph, _ = torch._dynamo.export(model)(inp)
1658        self.assertEqual(model(inp), graph(inp))
1659
1660    def test_export_decomp(self):
1661        def f(x):
1662            return x.t() + x.t()
1663
1664        def nop(x):
1665            return x.cos()
1666
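        # With aten.t.default mapped to nop in the decomposition table, no aten.t
        # nodes should remain in the exported graph; without the table, both
        # x.t() calls stay as aten.t nodes.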
1667        graph, _ = torch._dynamo.export(
1668            f,
1669            aten_graph=True,
1670            decomposition_table={torch.ops.aten.t.default: nop},
1671        )(torch.randn(5))
1672        self.assertEqual(
1673            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
1674            0,
1675        )
1676
1677        graph, _ = torch._dynamo.export(f, aten_graph=True, decomposition_table=None)(
1678            torch.randn(5)
1679        )
1680        self.assertEqual(
1681            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
1682            2,
1683        )
1684
1685    def test_export_decomp_asserts_bad_args(self):
1686        def f(x):
1687            return x.t() + x.t()
1688
1689        def nop(x):
1690            return x.cos()
1691
1692        with self.assertRaises(AssertionError):
1693            graph, _ = torch._dynamo.export(
1694                f,
1695                (torch.randn(5)),
1696                aten_graph=False,
1697                decomposition_table={torch.ops.aten.t.default: nop},
1698            )
1699
1700    @config.patch(capture_scalar_outputs=True)
1701    def test_export_with_module_layer(self):
1702        from functorch.experimental.control_flow import cond
1703
1704        class Module(torch.nn.Module):
1705            def __init__(self) -> None:
1706                super().__init__()
1707                self.linear = torch.nn.Linear(3, 3)
1708
1709            def forward(self, pred, x):
1710                def true_fn(val):
1711                    return self.linear(val) * torch.tensor(2)
1712
1713                def false_fn(val):
1714                    return self.linear(val) * torch.tensor(-1)
1715
1716                return cond(pred, true_fn, false_fn, [x])
1717
1718        mod = Module()
1719        x = torch.randn([3, 3])
1720        pred = torch.tensor(x[0][0].item() < 0)
1721        real_result = mod.forward(pred, x)
1722
1723        torch._dynamo.reset()
1724
1725        exported = torch._dynamo.export(mod.forward)(pred, x)
1726        out_graph = exported[0]
1727
1728        dynamo_result = out_graph(pred, x)
1729        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
1730
1731        # New X, just to show we did not specialize
1732        x = x * -1
1733        pred = torch.tensor(x[0][0].item() < 0)
1734        real_result_2 = mod.forward(pred, x)
1735        dynamo_result_2 = out_graph(pred, x)
1736        self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2))
1737
1738    @config.patch(capture_scalar_outputs=True)
1739    def test_export_with_cond_branches_calling_methods(self):
1740        from functorch.experimental.control_flow import cond
1741
1742        class Module(torch.nn.Module):
1743            # ok
1744            def __init__(self) -> None:
1745                super().__init__()
1746                self.linear = torch.nn.Linear(3, 3)
1747
1748            def t(self, val):
1749                return val + 1
1750
1751            def f(self, val):
1752                return val - 1
1753
1754            def true_fn(self, val):
1755                return self.linear(val) + self.t(val)
1756
1757            def false_fn(self, val):
1758                return self.linear(val) - self.f(val)
1759
1760            def forward(self, pred, x):
1761                return cond(pred, self.true_fn, self.false_fn, [x])
1762
1763        mod = Module()
1764        x = torch.randn([3, 3])
1765        pred = torch.tensor(x[0][0].item() < 0)
1766        real_result = mod.forward(pred, x)
1767        out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
1768        dynamo_result = out_graph(pred, x)
1769        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
1770
1771    @config.patch(capture_scalar_outputs=True)
1772    def test_export_with_cond_closure(self):
1773        from functorch.experimental.control_flow import cond
1774
1775        class Foo(torch.nn.Module):
1776            def __init__(self) -> None:
1777                super().__init__()
1778
1779            def forward(self, pred, x):
1780                def true_fn(x):
1781                    return x * 2
1782
1783                def false_fn(x):
1784                    return x - 2
1785
1786                return cond(pred, true_fn, false_fn, [x])
1787
1788        class Bar(torch.nn.Module):
1789            def __init__(self) -> None:
1790                super().__init__()
1791
1792            def forward(self, pred, x):
1793                def true_fn(x):
1794                    return x * 2
1795
1796                def false_fn(x):
1797                    return x - 2
1798
1799                return cond(pred, true_fn, false_fn, [x + 1])
1800
1801        class FooBar(torch.nn.Module):
1802            def __init__(self) -> None:
1803                super().__init__()
1804                self.linear = torch.nn.Linear(3, 3)
1805
1806            def forward(self, pred, x):
1807                y = x + x
1808
1809                def true_fn(x, y):
1810                    return self.linear(x) * (x + y)
1811
1812                def false_fn(x, y):
1813                    return x * (y - x)
1814
1815                return cond(pred, true_fn, false_fn, [x, y])
1816
1817        for Module in [Foo, Bar, FooBar]:
1818            mod = Module()
1819            x = torch.randn([3, 3], requires_grad=True)
1820            pred = torch.tensor(x[0][0].item() < 0)
1821            real_result = mod.forward(pred, x)
1822            out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
1823            dynamo_result = out_graph(pred, x)
1824            self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
1825
1826    def test_export_with_cond_with_closed_function(self):
1827        def hello(x):
1828            return x + 1
1829
1830        def hi(x):
1831            return x + 2
1832
1833        def foo(pred, x):
1834            def true_fn(x):
1835                return hello(x)
1836
1837            def false_fn(x):
1838                return hi(x)
1839
1840            return cond(pred, true_fn, false_fn, [x])
1841
1842        x = torch.randn(5)
1843        pred = x[0] > 0
1844        real_result = foo(pred, x)
1845        out_graph, _ = torch._dynamo.export(foo)(pred, x)
1846        dynamo_result = out_graph(pred, x)
1847        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
1848
1849    def test_export_with_cond_dynamic_shape_pred(self):
1850        from functorch.experimental.control_flow import cond
1851
1852        class Module(torch.nn.Module):
1853            def forward(self, x):
1854                def true_fn(x):
1855                    return x + x
1856
1857                def false_fn(x):
1858                    return x[:2]
1859
1860                return cond(x.shape[0] <= 2, true_fn, false_fn, [x])
1861
1862        class Module2(torch.nn.Module):
1863            def forward(self, x):
1864                def true_fn(x):
1865                    return x + x
1866
1867                def false_fn(x):
1868                    return x[:2]
1869
1870                return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))
1871
1872        mods = [Module(), Module2()]
1873        for mod in mods:
1874            x = torch.randn(2, 2)
1875            out_graph, guards = torch._dynamo.export(mod)(x)
1876            self.assertExpectedInline(
1877                out_graph.code.strip(),
1878                """\
1879def forward(self, x):
1880    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
1881    l_x_ = arg0
1882    size = l_x_.size()
1883    getitem = size[0];  size = None
1884    le = getitem <= 2;  getitem = None
1885    cond_true_0 = self.cond_true_0
1886    cond_false_0 = self.cond_false_0
1887    cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]);  le = cond_true_0 = cond_false_0 = l_x_ = None
1888    getitem_2 = cond[0];  cond = None
1889    return pytree.tree_unflatten([getitem_2], self._out_spec)""",
1890            )
1891            self.assertExpectedInline(
1892                out_graph.cond_true_0.code.strip(),
1893                """\
1894def forward(self, l_x_):
1895    l_x__1 = l_x_
1896    add = l_x__1 + l_x__1;  l_x__1 = None
1897    return (add,)""",
1898            )
1899            self.assertExpectedInline(
1900                out_graph.cond_false_0.code.strip(),
1901                """\
1902def forward(self, l_x_):
1903    l_x__1 = l_x_
1904    getitem = l_x__1[slice(None, 2, None)];  l_x__1 = None
1905    return (getitem,)""",
1906            )
1907            with self.assertRaisesRegex(
1908                torch._dynamo.exc.UncapturedHigherOrderOpError,
1909                "Cond doesn't work unless it is captured completely with torch.compile",
1910            ):
1911                # True branch and false branch return tensors of different shape
1912                torch._dynamo.export(mod)(torch.randn(3, 2))
1913
1914            # We specialize into one of the branches since the predicate is a Python boolean.
1915            test_x = torch.randn(3, 2)
1916            mod(test_x)
1917
1918    def test_export_with_map_cond(self):
1919        from functorch.experimental.control_flow import cond, map
1920
1921        class Module(torch.nn.Module):
1922            def inner(self, x, pred):
1923                def true_fn(x):
1924                    return x + x
1925
1926                def false_fn(x):
1927                    return x * x
1928
1929                return cond(pred, true_fn, false_fn, [x])
1930
1931            def forward(self, pred, xs):
1932                def body(x, pred):
1933                    return self.inner(x, pred)
1934
1935                return map(body, xs, pred)
1936
1937        mod = Module()
1938        x = torch.randn(3, 2, 1)
1939        pred_x = torch.tensor(True)
1940
1941        y = torch.randn(4, 3, 2)
1942        pred_y = torch.tensor(False)
1943        real_result = mod(pred_y, y)
1944
1945        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
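        # The exported graph should produce the same result as eager on inputs
        # with a different shape and predicate value than those used for tracing.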
1946        self.assertEqual(real_result, out_graph(pred_y, y))
1947
1948    def test_export_with_map_zero_sized_tensor(self):
1949        from functorch.experimental.control_flow import map
1950
1951        class Module(torch.nn.Module):
1952            def forward(self, xs):
1953                def body(x):
1954                    return x + 1
1955
1956                return map(body, xs)
1957
1958        mod = Module()
1959        xs = torch.randn(0, 2)
1960        with self.assertRaisesRegex(
1961            torch._dynamo.exc.Unsupported,
1962            "zero-sized tensor",
1963        ):
1964            out_graph, _ = torch._dynamo.export(mod)(xs)
1965
1966    def test_export_meta_val(self):
1967        def f(x, y, z):
1968            return x * y + z
1969
1970        gm, _ = torch._dynamo.export(
1971            f,
1972            aten_graph=True,
1973        )(
1974            torch.ones(3, 2),
1975            torch.zeros(3, 2),
1976            torch.ones(3, 2),
1977        )
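        # With aten_graph=True, every placeholder should carry a "val" entry
        # (the traced fake value) in its node meta.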
1978        for node in gm.graph.nodes:
1979            if node.op == "placeholder":
1980                self.assertIn("val", node.meta)
1981
1982    def test_input_container_type(self):
1983        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
1984            return {"a": x.sum() + sum(y).sum()}
1985
1986        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])
1987
1988        gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)
1989
1990        self.assertEqual(gm(*inp), f(*inp))
1991
1992    @config.patch(assume_static_by_default=False)
1993    def test_export_symbolic_shape(self):
1994        def f(x: torch.Tensor) -> torch.Tensor:
1995            return torch.empty(x.shape[0] * 2)
1996
1997        inp = (torch.randn(6, 5),)
1998        gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)
1999
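        # With assume_static_by_default=False, x.shape[0] stays symbolic, so the
        # aten graph should contain a sym_size call rather than a baked-in size.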
2000        has_sym_size = False
2001        for node in gm.graph.nodes:
2002            if node.target is torch.ops.aten.sym_size.int:
2003                has_sym_size = True
2004
2005        self.assertTrue(has_sym_size)
2006
2007    @config.patch(assume_static_by_default=False)
2008    def test_dynamic_slicing(self):
2009        def f(x):
2010            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
2011
2012        gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
2013
2014        inp = torch.randn(6, 7)
2015        self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)
2016
2017        count = 0
2018        # The aten graph should flatten getitem calls into actual
2019        # aten.slice kernel calls.
2020        for node in gm_aten_mode.graph.nodes:
2021            if (
2022                node.op == "call_function"
2023                and node.target == torch.ops.aten.slice.Tensor
2024            ):
2025                count += 1
2026
2027        self.assertEqual(count, 2)
2028
2029        gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5))
2030
2031        # In torch mode, the graph should contain a single operator.getitem
2032        # call, coming from the tensor slice itself; the Tensor class has its
2033        # own getitem method which only gets translated to aten.slice when
2034        # tracing the aten graph.
2035        count = 0
2036        for node in gm_torch_mode.graph.nodes:
2037            if node.op == "call_function" and node.target == operator.getitem:
2038                count += 1
2039
2040        self.assertEqual(count, 1)
2041        self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)
2042
2043    def test_dynamic_slicing_invalid(self):
2044        def g(x, y):
2045            return x[y : x.shape[0]]
2046
2047        with self.assertRaisesRegex(
2048            torch._dynamo.exc.Unsupported,
2049            "Dynamic slicing on data-dependent value is not supported",
2050        ):
2051            torch._dynamo.export(
2052                g,
2053                aten_graph=True,
2054            )(
2055                torch.randn(4, 5),
2056                torch.tensor(2),
2057            )
2058
2059    @config.patch(capture_scalar_outputs=True)
2060    def test_dynamic_slicing_simple(self):
2061        def f(x):
2062            return x[slice(None, None, None)]
2063
2064        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
2065
2066        inp = torch.randn(6, 7)
2067        self.assertEqual(gm(inp), f(inp))
2068
2069    def test_pre_dispatch_simple(self):
2070        def f(x):
2071            y = torch.ones_like(x)
2072            return torch.matmul(x, y)
2073
2074        gm, _ = torch._dynamo.export(
2075            f,
2076            aten_graph=True,
2077            pre_dispatch=True,
2078            tracing_mode="fake",
2079        )(
2080            torch.randn(5, 5),
2081        )
2082
2083        inp = torch.randn(6, 6)
2084        self.assertEqual(gm(inp), f(inp))
2085        self.assertExpectedInline(
2086            gm.code.strip(),
2087            """\
2088def forward(self, x):
2089    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
2090    arg0_1 = arg0
2091    ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
2092    matmul = torch.ops.aten.matmul.default(arg0_1, ones_like);  arg0_1 = ones_like = None
2093    return pytree.tree_unflatten([matmul], self._out_spec)""",
2094        )
2095
2096    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
2097    def test_export_cond_in_aten_symbolic(self):
2098        class ConditionOp(torch.nn.Module):
2099            def true_fn(self, x, y):
2100                return x * y
2101
2102            def false_fn(self, x, y):
2103                return x + y
2104
2105            def forward(self, pred, x, y):
2106                return cond(pred, self.true_fn, self.false_fn, [x, y])
2107
2108        model = ConditionOp()
2109        inp = (
2110            torch.tensor(False),
2111            torch.randn(4, 4),
2112            torch.randn(4, 4),
2113        )
2114        gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp)
2115
2116        gm.print_readable()
2117
2118        self.assertEqual(gm(*inp), model(*inp))
2119
2120    def test_export_with_kwargs(self):
2121        def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
2122            out = pos0
2123            for arg in tuple0:
2124                out *= arg
2125            for arg in myargs:
2126                out *= arg
2127            out *= mykw0
2128            out *= mykwargs["input0"] * mykwargs["input1"]
2129            return out
2130
2131        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
2132        tuple0 = (torch.randn(4), torch.randn(4))
2133        mykw0 = torch.randn(4)
2134        pos0 = torch.randn(4)
2135        myargs = [torch.randn(4), torch.randn(4)]
2136
2137        expected_argument_names = [
2138            "pos0",
2139            "tuple0",
2140            "myargs_0",
2141            "myargs_1",
2142            "mykw0",
2143            "input0",
2144            "input1",
2145        ]
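        # *myargs and **mykwargs are flattened into individual argument names
        # in the exported graph's signature.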
2146        self._test_export_preserving_original_signature(
2147            fn_with_kwargs,
2148            expected_argument_names,
2149            pos0,
2150            tuple0,
2151            *myargs,
2152            mykw0=mykw0,
2153            **mykwargs,
2154        )
2155
2156    def test_export_with_kwargs_and_empty_args(self):
2157        def fn_with_kwargs(mykw0=None, **mykwargs):
2158            out = mykw0
2159            out *= mykwargs["input0"] * mykwargs["input1"]
2160            return out
2161
2162        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
2163        mykw0 = torch.randn(4)
2164
2165        expected_argument_names = ["mykw0"] + list(mykwargs.keys())
2166        self._test_export_preserving_original_signature(
2167            fn_with_kwargs, expected_argument_names, mykw0, **mykwargs
2168        )
2169
2170    def test_export_with_args_and_empty_kwargs(self):
2171        def fn_with_kwargs(pos0, tuple0, *myargs):
2172            out = pos0
2173            for arg in tuple0:
2174                out *= arg
2175            for arg in myargs:
2176                out *= arg
2177            return out
2178
2179        tuple0 = (torch.randn(4), torch.randn(4))
2180        pos0 = torch.randn(4)
2181        myargs = [torch.randn(4), torch.randn(4)]
2182
2183        expected_argument_names = ["pos0", "tuple0", "myargs_0", "myargs_1"]
2184        self._test_export_preserving_original_signature(
2185            fn_with_kwargs, expected_argument_names, pos0, tuple0, *myargs
2186        )
2187
2188    @common_utils.parametrize(
2189        "default_value",
2190        [
2191            common_utils.subtest(None, name="None"),
2192            common_utils.subtest(42.0, name="float"),
2193            common_utils.subtest(
2194                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
2195                torch.randn(4),
2196                name="tensor",
2197                decorators=[unittest.expectedFailure],
2198            ),
2199            common_utils.subtest(
2200                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
2201                (torch.randn(4),),
2202                name="tuple",
2203                decorators=[unittest.expectedFailure],
2204            ),
2205        ],
2206    )
2207    def test_export_with_args_with_default(self, default_value):
2208        def fn(pos0, pos1_default=default_value):
2209            out = pos0
2210            if pos1_default is None:
2211                pos1_default = torch.randn(4)
2212            if isinstance(pos1_default, tuple):
2213                pos1_default = pos1_default[0]
2214            out *= pos1_default
2215            return out
2216
2217        pos0 = torch.randn(4)
2218        expected_argument_names = ["pos0"]
2219        self._test_export_preserving_original_signature(
2220            fn, expected_argument_names, pos0
2221        )
2222
2223    @common_utils.parametrize(
2224        "default_value",
2225        [
2226            common_utils.subtest(None, name="None"),
2227            common_utils.subtest(42.0, name="float"),
2228            common_utils.subtest(
2229                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
2230                torch.randn(4),
2231                name="tensor",
2232                decorators=[unittest.expectedFailure],
2233            ),
2234            common_utils.subtest(
2235                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
2236                (torch.randn(4),),
2237                name="tuple",
2238                decorators=[unittest.expectedFailure],
2239            ),
2240        ],
2241    )
2242    def test_export_with_kwargs_with_default(self, default_value):
2243        def fn(pos0, *, kw0, kw1_default=default_value, **kwargs):
2244            out = pos0
2245            out += kw0
2246            if kw1_default is None:
2247                kw1_default = torch.randn(4)
2248            elif isinstance(kw1_default, tuple):
2249                kw1_default = kw1_default[0]
2250            out += kw1_default
2251            out += kwargs["kw2"]
2252            return out
2253
2254        pos0 = torch.randn(4)
2255        kw0 = torch.randn(4)
2256        kw2 = torch.randn(4)
2257
2258        args = (pos0,)
2259        kwargs = {"kw0": kw0, "kw2": kw2}
2260        expected_argument_names = ["pos0", "kw0", "kw2"]
2261        self._test_export_preserving_original_signature(
2262            fn, expected_argument_names, *args, **kwargs
2263        )
2264
2265    def test_export_with_wrapped_fn(self):
2266        # To ensure dynamo.export is robust to wrapped functions
2267        # when it cannot use `inspect` to retrieve original signature
2268        # info.
2269        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
2270            out = pos0
2271            out += pos1
2272            out += kw0
2273            out += kw1
2274            for arg in args:
2275                out += arg
2276            for kwarg in kwargs.values():
2277                out += kwarg
2278            return out
2279
2280        def wrapped_fn(*args, **kwargs):
2281            return _fn(*args, **kwargs)
2282
2283        pos0 = torch.randn(4)
2284        kw0 = torch.randn(4)
2285        args = (pos0, torch.randn(4), torch.randn(4))
2286        kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
2287        expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
2288            kwargs.keys()
2289        )
2290
2291        self._test_export_preserving_original_signature(
2292            wrapped_fn, expected_argument_names, *args, **kwargs
2293        )
2294
2295    def test_export_with_functools_wrapped_method(self):
2296        def test_decorator(func):
2297            @functools.wraps(func)
2298            def wrapper(*args, **kwargs):
2299                return func(*args, **kwargs)
2300
2301            return wrapper
2302
2303        class MyModule(torch.nn.Module):
2304            def __init__(self) -> None:
2305                super().__init__()
2306
2307            def forward(self, x):
2308                return x
2309
2310            @test_decorator
2311            def method_to_test(self, pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
2312                out = pos0
2313                out += pos1
2314                out += kw0
2315                out += kw1
2316                for arg in args:
2317                    out += arg
2318                for kwarg in kwargs.values():
2319                    out += kwarg
2320                return out
2321
2322        pos0 = torch.randn(4)
2323        pos1 = torch.randn(4)
2324        unnamed_pos = torch.randn(4)
2325        kw0 = torch.randn(4)
2326        args = (pos0, pos1, unnamed_pos)
2327        kwargs = {"kw0": kw0, "kw2": torch.randn(4), "unnamed_kw": torch.randn(4)}
2328        expected_argument_names = [
2329            "pos0",
2330            "pos1",
2331            "args_0",  # 3rd unnamed positional argument
2332        ] + list(kwargs.keys())
2333        m = MyModule()
2334
2335        self._test_export_preserving_original_signature(
2336            m.method_to_test, expected_argument_names, *args, **kwargs
2337        )
2338
2339    def test_export_with_functools_wrapped_fn(self):
2340        def test_decorator(func):
2341            @functools.wraps(func)
2342            def wrapper(*args, **kwargs):
2343                return func(*args, **kwargs)
2344
2345            return wrapper
2346
2347        @test_decorator
2348        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
2349            out = pos0
2350            out += pos1
2351            out += kw0
2352            out += kw1
2353            for arg in args:
2354                out += arg
2355            for kwarg in kwargs.values():
2356                out += kwarg
2357            return out
2358
2359        def wrapped_fn(*args, **kwargs):
2360            return _fn(*args, **kwargs)
2361
2362        pos0 = torch.randn(4)
2363        kw0 = torch.randn(4)
2364        args = (pos0, torch.randn(4), torch.randn(4))
2365        kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
2366        expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
2367            kwargs.keys()
2368        )
2369
2370        self._test_export_preserving_original_signature(
2371            wrapped_fn, expected_argument_names, *args, **kwargs
2372        )
2373
2374    def _test_export_preserving_original_signature(
2375        self, fn, expected_argument_names: Sequence[str], *args, **kwargs
2376    ):
2377        torch._dynamo.reset()
2378        exported = torch._dynamo.export(
2379            fn,
2380            *args,
2381            **kwargs,
2382            aten_graph=False,
2383        )
2384
2385        out_graph = exported[0]
2386        dynamo_result = out_graph(*args, **kwargs)
2387        real_result = fn(*args, **kwargs)
2388        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
2389
2390        # Check that the exported graph preserves same argument names.
2391        self.assertEqual(
2392            inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names
2393        )
2394
2395    def test_dataclass_input_output(self):
2396        from dataclasses import dataclass
2397
2398        @dataclass
2399        class Tensors:
2400            x: torch.Tensor
2401            y: torch.Tensor
2402
2403        def f(t):
2404            return t.x + t.y
2405
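        # Plain dataclasses are not registered with pytree, so export should
        # reject them both as inputs and as outputs below.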
2406        with self.assertRaisesRegex(
2407            UserError,
2408            "It looks like one of the inputs with type .*Tensors.* "
2409            "is not supported or pytree-flattenable",
2410        ):
2411            torch._dynamo.export(f, aten_graph=False)(
2412                Tensors(x=torch.randn(10), y=torch.randn(10))
2413            )
2414
2415        def f(x, y):
2416            return Tensors(x=x.sin(), y=y.cos())
2417
2418        with self.assertRaisesRegex(
2419            UserError,
2420            "It looks like one of the outputs with type .*Tensors.* "
2421            "is not supported or pytree-flattenable",
2422        ):
2423            torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10))
2424
2425    def test_empty(self):
2426        def f(x):
2427            return x
2428
2429        exported = torch._dynamo.export(f)(torch.randn(3, 3))
2430        out_graph = exported[0]
2431        inp = torch.randn(3, 3)
2432        self.assertTrue(torch._dynamo.utils.same(inp, out_graph(inp)))
2433
2434        class M(torch.nn.Module):
2435            def __init__(self) -> None:
2436                super().__init__()
2437                self.a = torch.ones(3, 3)
2438
2439            def forward(self):
2440                return self.a
2441
2442        exported = torch._dynamo.export(M())()
2443        out_graph = exported[0]
2444        self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph()))
2445
2446    @unittest.skipIf(not TEST_CUDA, "No CUDA available.")
2447    def test_export_with_parameters(self):
2448        class MyModule(torch.nn.Module):
2449            def __init__(self) -> None:
2450                super().__init__()
2451                self.features = torch.nn.Sequential(
2452                    torch.nn.Conv2d(
2453                        3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
2454                    ),
2455                    torch.nn.ReLU(inplace=True),
2456                )
2457
2458            def forward(self, x):
2459                return self.features(x)
2460
2461        model = MyModule().eval().cuda()
2462        random_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),)
2463        dim_x = torch.export.Dim("dim_x", min=1, max=32)
2464        exp_program = torch.export.export(
2465            model, random_inputs, dynamic_shapes={"x": {0: dim_x}}
2466        )
2467        output_buffer = io.BytesIO()
2468        # Tests if we can restore saved nn.Parameters when we load them again
2469        torch.export.save(exp_program, output_buffer)
2470        loaded_model = torch.export.load(output_buffer)
2471        self.assertTrue(
2472            isinstance(
2473                loaded_model.module().get_parameter("features.0.weight"),
2474                torch.nn.Parameter,
2475            )
2476        )
2477
2478    def test_export_fast_binary_broadcast_check(self):
2479        # This test looks at the case where we erroneously create a guard
2480        # when checking the equality of the operands' shape and the output
2481        # shape during FakeTensor's binary op fast path.
2482
2483        class MyModel(torch.nn.Module):
2484            def forward(self, a, b):
2485                # final shape is (dim0, 4, 8)
2486                # order matters since a & the output have the same shape
2487                return b + a
2488
2489        a = torch.randn(100, 4, 8)
2490        b = torch.randn(4, 8)
2491        model = MyModel().eval().cuda()
2492        batchsize = torch.export.Dim("dim0", min=3, max=1024)
2493        dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}
2494
2495        torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec)
2496
2497    def test_export_fast_binary_broadcast_check_unbacked(self):
2498        class MyModel(torch.nn.Module):
2499            def forward(self, numel, scalar):
2500                u0 = numel.item()
2501                torch._check_is_size(u0)
2502                x = torch.ones(u0 + 1)
2503                return scalar - x
2504
2505        model = MyModel().eval().cuda()
2506        numel = torch.tensor(10)
2507        scalar = torch.randn(1)
2508        torch.export.export(model, (numel, scalar))
2509
2510    def test_export_meta(self):
2511        class MyModule(torch.nn.Module):
2512            def __init__(self) -> None:
2513                super().__init__()
2514                self.p = torch.nn.Parameter(torch.ones(2, 3))
2515
2516            def forward(self, x):
2517                return self.p + x
2518
2519        with torch.device("meta"):
2520            m = MyModule()
2521
2522        inp = torch.ones(2, 3, device="meta")
2523        exported = torch._dynamo.export(m)(inp)
2524        out_graph = exported[0]
2525        dynamo_result = out_graph(inp)
2526        self.assertEqual(dynamo_result, m(inp))
2527
2528    def test_constraint_violation_error_messages(self):
2529        class Foo(torch.nn.Module):
2530            def forward(self, x):
2531                if x.shape[0] == x.shape[1] * 2:
2532                    return x + 1
2533                else:
2534                    return x + 2
2535
2536        foo = Foo()
2537
2538        t = torch.zeros([8, 4])
2539        dim0 = torch.export.Dim("dim0", min=3, max=10)
2540        dim1 = torch.export.Dim("dim1")
2541        dynamic_shapes = {"x": (dim0, dim1)}
2542
2543        with self.assertRaisesRegex(
2544            torch._dynamo.exc.UserError,
2545            "Constraints violated .*!(.*\n)*.*"
2546            "by dim0 = 2\\*dim1(.*\n)*.*"
2547            "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
2548        ):
2549            torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes)
2550
2551        class Bar(torch.nn.Module):
2552            def forward(self, x):
2553                if x.shape[0] == 5:
2554                    return x + 1
2555                else:
2556                    return x + 2
2557
2558        bar = Bar()
2559
2560        t = torch.zeros([5])
2561        dim0 = torch.export.Dim("dim0", min=3, max=8)
2562        dynamic_shapes = {"x": (dim0,)}
2563        with self.assertRaisesRegex(
2564            torch._dynamo.exc.UserError,
2565            "Not all values.*valid.*inferred to be a constant",
2566        ):
2567            torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes)
2568
2569        class Qux(torch.nn.Module):
2570            def forward(self, x):
2571                if x.shape[0] > 5 and x.shape[0] < 10:
2572                    return x + 1
2573                else:
2574                    return x + 2
2575
2576        qux = Qux()
2577
2578        t = torch.zeros([7])
2579        dim0 = torch.export.Dim("dim0", min=3, max=8)
2580        dynamic_shapes = {"x": (dim0,)}
2581        with self.assertRaisesRegex(
2582            torch._dynamo.exc.UserError,
2583            "Not all values.*satisfy the generated guard",
2584        ):
2585            torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes)
2586
2587    def test_untracked_inputs_in_constraints(self):
2588        from copy import copy
2589
2590        class Foo(torch.nn.Module):
2591            def forward(self, x, y):
2592                return y + 1
2593
2594        foo = Foo()
2595
2596        x = torch.randn(2)
2597        y = torch.randn(5, 4)
2598
2599        dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
2600        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
2601
2602        example_inputs = (copy(x), y)
2603        ep = torch.export.export(foo, example_inputs, dynamic_shapes=dynamic_shapes)
2604        ep.module()(torch.randn(3), y)  # no specialization error
2605
2606    def test_export_raise_guard_full_constraint(self):
2607        y = torch.randn([3, 3, 3])
2608
2609        def my_dyn_fn(x):
2610            if x.shape[0] == 3:
2611                return x.sin()
2612            return x.cos()
2613
2614        torch._dynamo.export(my_dyn_fn)(y)
2615
2616        with self.assertRaises(ConstraintViolationError):
2617            torch._dynamo.export(
2618                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
2619            )(y)
2620
2621    def test_export_module_specify_constraints_signature(self):
2622        y = torch.randn([3, 3, 3])
2623
2624        class Mod(torch.nn.Module):
2625            def forward(self, x):
2626                if x.shape[0] == 3:
2627                    return x.sin()
2628                return x.cos()
2629
2630        mod = Mod()
2631        torch._dynamo.export(mod)(y)
2632
2633        with self.assertRaisesRegex(ConstraintViolationError, "dimx = 3"):
2634            torch._dynamo.export(mod, dynamic_shapes=({0: torch.export.Dim("dimx")},))(
2635                y
2636            )
2637
2638    def test_export_raise_guard_partial_constraint(self):
2639        y = torch.randn([3, 3, 3])
2640
2641        def my_dyn_fn(x):
2642            if x.shape[0] > 3:
2643                return x.sin()
2644            return x.cos()
2645
2646        torch._dynamo.export(my_dyn_fn)(y)
2647
2648        with self.assertRaises(ConstraintViolationError):
2649            torch._dynamo.export(
2650                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
2651            )(y)
2652
2653    def test_export_raise_on_relationship(self):
2654        y = torch.randn([3, 3, 3])
2655
2656        def my_dyn_fn(a, b, c):
2657            if a.shape[0] == b.shape[1] == c.shape[2]:
2658                return a.sin()
2659
2660            return a.cos()
2661
2662        torch._dynamo.export(my_dyn_fn)(y, y, y)
2663        dim = torch.export.Dim("dim")
2664        dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
2665        with self.assertRaises(ConstraintViolationError):
2666            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
2667        dynamic_shapes = ({0: dim}, {1: dim}, {2: dim})
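        # Using the same Dim for the dims that are actually compared equal
        # (a.shape[0], b.shape[1], c.shape[2]) satisfies the guard, so export succeeds.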
2668        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
2669
2670    def test_export_no_raise(self):
2671        y = torch.randn([3, 3, 3])
2672
2673        def my_dyn_fn(a, b, c):
2674            if a.shape[1] == 3:
2675                return a.cos()
2676            return a * b * c
2677
2678        torch._dynamo.export(my_dyn_fn)(y, y, y)
2679        dim = torch.export.Dim("dim")
2680        dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
2681        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
2682
2683    def test_export_multi_dynamic_dim_unsafe_relationship(self):
2684        x = torch.randn([3, 3, 3])
2685        y = torch.randn([2, 2, 2])
2686        z = torch.randn([3, 3, 3])
2687
2688        def my_dyn_fn(a, b, c):
2689            if a.shape[0] == c.shape[0]:
2690                return a.cos()
2691            return a * c, b
2692
2693        torch._dynamo.export(my_dyn_fn)(x, y, z)
2694        dimx, dimy, dimz = torch.export.dims("dimx", "dimy", "dimz")
2695        dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
2696        with self.assertRaises(ConstraintViolationError):
2697            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
2698        dimz = dimx
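        # Reusing dimx for dimz encodes the a.shape[0] == c.shape[0] relationship,
        # so export succeeds.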
2699        dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
2700        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
2701
2702    def test_remove_redundant_dynamic_dim_in_error_message(self):
2703        class Foo(torch.nn.Module):
2704            def forward(self, x, y):
2705                if x.shape[0] == y["k"].shape[0]:
2706                    return x + 1
2707                else:
2708                    return x - 1
2709
2710        foo = Foo()
2711
2712        a = torch.randn(3)
2713        b = torch.randn(3)
2714        dim0_a, dim0_b = torch.export.dims("dim0_a", "dim0_b")
2715        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "dim0_b = dim0_a"):
2716            torch.export.export(
2717                foo,
2718                (a, {"k": b}),
2719                dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}},
2720            )
2721
2722    def test_enforce_equalities(self):
2723        class Bar(torch.nn.Module):
2724            def forward(self, x, y):
2725                return torch.matmul(x, y)
2726
2727        bar = Bar()
2728
2729        batch, size = torch.export.dims("batch", "size")
2730        dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)}
2731
2732        x = torch.randn(10, 3, 3)
2733        y = torch.randn(10, 3, 4)
2734        with self.assertRaisesRegex(
2735            torch._dynamo.exc.UserError,
2736            ".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3",
2737        ):
2738            torch.export.export(
2739                bar,
2740                (x, y),
2741                dynamic_shapes=dynamic_shapes,
2742            )
2743        y = torch.randn(10, 3, 3)
2744        ebar = torch.export.export(
2745            bar,
2746            (x, y),
2747            dynamic_shapes=dynamic_shapes,
2748        )
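        # Both placeholders should end up with the same symbolic sizes
        # (s0, s1, s1), reflecting the enforced batch/size equalities.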
2749        self.assertEqual(
2750            [
2751                str(node.meta["val"].shape)
2752                for node in ebar.graph_module.graph.nodes
2753                if node.op == "placeholder"
2754            ],
2755            ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
2756        )
2757
2758    @torch._dynamo.config.patch(
2759        capture_dynamic_output_shape_ops=True,
2760        specialize_int=True,
2761        capture_scalar_outputs=True,
2762    )
2763    def test_export_preserve_constraints_as_metadata_tensor(self):
2764        def f(x):
2765            b = x.nonzero()
2766            torch._check(b.shape[0] >= 2)
2767            torch._check(b.shape[0] <= 5)
2768            return b
2769
2770        y = torch.tensor([8, 8, 6])
2771        gm, _ = torch._dynamo.export(
2772            f,
2773            aten_graph=True,
2774            tracing_mode="symbolic",
2775        )(y)
2776
2777    @config.patch(
2778        capture_dynamic_output_shape_ops=True,
2779        specialize_int=True,
2780        capture_scalar_outputs=True,
2781    )
2782    def test_exported_graph_serialization(self):
2783        def f(x, y):
2784            b = x.item()
2785            torch._check_is_size(b)
2786            return torch.empty((b, y.shape[0]))
2787
2788        x = torch.tensor([3])
2789        y = torch.randn([8, 8, 6])
2790        example_inputs = [x, y]
2791        dynamic_shapes = (None, {0: torch.export.Dim("dimy", min=6, max=10)})
2792        gm, _ = torch._dynamo.export(
2793            f,
2794            dynamic_shapes=dynamic_shapes,
2795            aten_graph=True,
2796            tracing_mode="symbolic",
2797        )(*example_inputs)
2798
2799        # Ensure the exported graph module with metadata is serializable;
2800        # the metadata itself won't be saved in the serialized module.
2801        buffer = io.BytesIO()
2802        torch.save(gm, buffer)
2803
2804    def test_export_dynamic_dim_not_1(self):
2805        x = torch.randn([1, 1, 1])
2806
2807        def my_dyn_fn(a):
2808            if a.shape[0] != 1:
2809                return a.cos()
2810            return a * a
2811
2812        torch._dynamo.export(my_dyn_fn)(x)
2813        with self.assertRaises(ConstraintViolationError):
2814            torch._dynamo.export(
2815                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
2816            )(x)
2817
2818    def test_symbool(self):
2819        def f(x):
2820            a = torch.scalar_tensor(x.shape[0] > 4)
2821            return x.sin().sum() + a.sum()
2822
2823        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
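        # The SymBool x.shape[0] > 4 is captured symbolically rather than
        # specialized, so the graph also evaluates correctly for a smaller input.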
2824        self.assertEqual(gm(torch.ones(3, 4)), f(torch.ones(3, 4)))
2825
2826    def test_export_multi_dynamic_dim_constraint(self):
2827        x = torch.randn([3, 3, 3])
2828        y = torch.randn([2, 2, 2])
2829        z = torch.randn([3, 3, 3])
2830
2831        def my_dyn_fn(a, b, c):
2832            if a.shape[0] == c.shape[0]:
2833                return a.cos()
2834            return a * c, b
2835
2836        torch._dynamo.export(my_dyn_fn)(x, y, z)
2837        dimx_0, dimx_1, dimx_2 = torch.export.dims("dimx_0", "dimx_1", "dimx_2")
2838        dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, None)
2839        with self.assertRaises(ConstraintViolationError):
2840            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
2841        dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, {0: dimx_0})
2842        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
2843
2844    def test_export_dynamic_dim_range_constraint(self):
2845        x = torch.ones(6, 4, 4)
2846        dynamic_shapes = ({0: torch.export.Dim("dimx", min=5, max=6)},)
2847
2848        def foo(x):
2849            if x.shape[0] > 3:  # ok
2850                return x.sin()
2851            return x.cos()
2852
2853        torch._dynamo.export(
2854            foo,
2855            dynamic_shapes=dynamic_shapes,
2856            aten_graph=True,
2857        )(x)
2858
2859        def bar(x):
2860            if x.shape[0] > 5:  # error
2861                return x.sin()
2862            return x.cos()
2863
2864        with self.assertRaises(ConstraintViolationError):
2865            torch._dynamo.export(
2866                bar,
2867                dynamic_shapes=dynamic_shapes,
2868                aten_graph=True,
2869            )(x)
2870
2871    def test_trivial_constraint(self):
2872        class Foo(torch.nn.Module):
2873            def forward(self, x):
2874                # complex divisibility condition
2875                if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0:
2876                    return x + 1
2877                else:
2878                    return x - 1
2879
2880        foo = Foo()
2881
2882        class Bar(torch.nn.Module):
2883            def forward(self, x):
2884                # trivially true
2885                if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0:
2886                    return x + 1
2887                else:
2888                    return x - 1
2889
2890        bar = Bar()
2891
2892        class Qux(torch.nn.Module):
2893            def forward(self, x):
2894                # simple divisibility condition (not trivially true)
2895                if (3 * x.shape[0]) % 2 == 0:
2896                    return x + 1
2897                else:
2898                    return x - 1
2899
2900        qux = Qux()
2901
2902        x = torch.randn(12)
2903        dim0 = torch.export.Dim("dim0", max=100)
2904        dynamic_shapes = {"x": (dim0,)}
2905        with self.assertRaisesRegex(
2906            torch._dynamo.exc.UserError,
2907            r"Constraints violated \(dim0\)",
2908        ):
2909            torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes)
2910
2911        torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes)
2912
2913        with self.assertRaisesRegex(
2914            torch._dynamo.exc.UserError,
2915            r"Constraints violated \(dim0\)",
2916        ):
2917            torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes)
2918
2919    def test_list_contains(self):
2920        def func(x):
2921            assert x.size(-1) in [4, 5, 6], "bad"
2922            return x + x
2923
2924        inps = (torch.randn(1, 5),)
2925        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
2926        real_result = opt_func(*inps)
2927
2928        torch._dynamo.reset()
2929
2930        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
2931        out_graph = exported[0]
2932
2933        dynamo_result = out_graph(*inps)
2934
2935        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
2936
2937    def test_list_not_contains(self):
2938        def func(x):
2939            assert x.size(0) not in [4, 5, 6], "bad1"
2940            assert "monkey" not in ["cow", "pig"], "bad2"
2941            return x + x
2942
2943        inps = (torch.randn(1, 5),)
2944        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
2945        real_result = opt_func(*inps)
2946
2947        torch._dynamo.reset()
2948
2949        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
2950        out_graph = exported[0]
2951
2952        dynamo_result = out_graph(*inps)
2953
2954        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
2955
2956    def test_export_identity(self):
2957        inp = torch.tensor([0.1, 0.1])
2958
2959        def func(x):
2960            return x
2961
2962        torch._dynamo.reset()
2963        exported, _ = torch._dynamo.export(func)(inp)
2964        dynamo_result = exported(inp)
2965        self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result))
2966
2967    def test_export_specialized_int(self):
2968        class Foo(torch.nn.Module):
2969            def __init__(
2970                self,
2971                input_dim,
2972            ):
2973                super().__init__()
2974                self.torch_module = torch.nn.LayerNorm(
2975                    input_dim, eps=1e-5, elementwise_affine=True
2976                )
2977                self.int_val = 100
2978
2979            def forward(self, input):
2980                return input.cos() * self.int_val * self.torch_module.eps
2981
2982        mod = Foo(128)
2983        inp = torch.randn(3, 128)
2984
2985        # In export, int & float in forward should always be specialized
2986        gm, _ = torch._dynamo.export(mod, aten_graph=True)(inp)
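        # Only the tensor input should remain as a placeholder; int_val and eps
        # are specialized into the graph as constants.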
2987        count = 0
2988        for node in gm.graph.nodes:
2989            if node.op == "placeholder":
2990                count += 1
2991        self.assertEqual(count, 1)
2992
2993    def test_export_with_nonzero_static(self):
2994        class BasicModule(torch.nn.Module):
2995            def __init__(self, static_size):
2996                super().__init__()
2997                self.static_size = static_size
2998
2999            def forward(self, x):
3000                return torch.nonzero_static(x, size=self.static_size)
3001
3002        input_tensors = torch.tensor([6, 8]), torch.zeros(2, 3)
3003        static_sizes = 3, 4
3004        for input_tensor, static_size in zip(input_tensors, static_sizes):
3005            m = BasicModule(static_size)
3006            gm, _ = torch._dynamo.export(m, aten_graph=True)(input_tensor)
3007            res = gm(input_tensor)
3008            self.assertEqual(res.size(0), static_size)
3009            self.assertTrue(
3010                torch._dynamo.utils.same(
3011                    res, torch.nonzero_static(input_tensor, size=static_size)
3012                )
3013            )
3014
3015    def test_export_pass_arg_by_name(self):
3016        class BasicModule(torch.nn.Module):
3017            def __init__(self) -> None:
3018                super().__init__()
3019                self.my_lin = torch.nn.Linear(3, 4, bias=True)
3020
3021            def forward(self, x):
3022                return self.my_lin(x)
3023
3024        mod, input_tensor = BasicModule(), torch.randn(2, 3)
3025        gm, guard = torch._dynamo.export(mod, aten_graph=True)(input_tensor)
3026        ref = mod(x=input_tensor)
3027        res = gm(x=input_tensor)
3028        self.assertTrue(torch._dynamo.utils.same(ref, res))
3029
3030    def test_export_pass_arg_by_name_star_args(self):
3031        class BasicModule(torch.nn.Module):
3032            def __init__(self) -> None:
3033                super().__init__()
3034                self.my_lin = torch.nn.Linear(3, 4, bias=True)
3035
3036            def forward(self, *args):
3037                return self.my_lin(args[0]) * self.my_lin(args[1])
3038
3039        mod, input_tensor, input_tensor2 = (
3040            BasicModule(),
3041            torch.randn(2, 3),
3042            torch.randn(2, 3),
3043        )
3044        gm, guard = torch._dynamo.export(mod, aten_graph=True)(
3045            input_tensor, input_tensor2
3046        )
3047        ref = mod(input_tensor, input_tensor2)
3048        res = gm(input_tensor, input_tensor2)
3049        self.assertTrue(torch._dynamo.utils.same(ref, res))
3050
3051    def test_export_mark_dynamic_conflict_dynamic_dim(self):
3052        y = torch.randn([3, 3, 3])
3053
3054        def my_dyn_fn(x):
3055            if x.shape[0] > 3:
3056                return x.sin()
3057            return x.cos()
3058
3059        torch._dynamo.mark_dynamic(y, 0)
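        # With dim 0 marked dynamic, the shape guard introduced by the
        # x.shape[0] > 3 branch cannot hold for all values of the Dim, so
        # export is expected to raise a constraint violation.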
3060        with self.assertRaisesRegex(
3061            RuntimeError,
3062            "Constraints violated",
3063        ):
3064            torch._dynamo.export(
3065                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},)
3066            )(y)
3067
3068    def test_export_dynamic_dim_cleanup(self):
3069        y = torch.randn([3, 3, 3])
3070
3071        def my_dyn_fn(x):
3072            return x.cos()
3073
3074        torch._dynamo.export(my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},))(
3075            y
3076        )
3077
3078    @config.patch(capture_dynamic_output_shape_ops=True)
3079    def test_export_dynamic_control_flow_error(self):
3080        def f(x):
3081            if x.nonzero() > 3:
3082                return x.cos()
3083            return x.sin()
3084
3085        with self.assertRaisesRegex(
3086            torch._dynamo.exc.UserError,
3087            "Dynamic control flow is not supported at the moment",
3088        ):
3089            gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6))
3090
3091    @config.patch(assume_static_by_default=False)
3092    def test_export_persist_assert(self):
3093        def f(x):
3094            assert x[0].sum() > 4, "Shape must be more than 4"
3095            return x.cos() + x.sin()
3096
3097        gm, guard = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
3098            torch.ones(5, 4, 6)
3099        )
3100
3101        def has_aten_op(gm, op):
3102            for node in gm.graph.nodes:
3103                if node.target == op:
3104                    return True
3105            return False
3106
3107        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
3108
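        # The persisted assert should survive dead code elimination since
        # _assert_async is side-effectful.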
3109        gm.graph.eliminate_dead_code()
3110        gm.recompile()
3111        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
3112
3113        with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
3114            gm(torch.zeros(3, 4, 5))
3115
3116    @common_utils.parametrize(
3117        "type_fn",
3118        [
3119            common_utils.subtest(type, name="builtin"),
3120            common_utils.subtest(lambda obj: obj.__class__, name="attr"),
3121        ],
3122    )
3123    def test_access_class_method_from_user_class(self, type_fn):
3124        class A:
3125            @classmethod
3126            def func(cls):
3127                return torch.Tensor([4, 5])
3128
3129        def f(x):
3130            a = A()
3131            return x.sum() + type_fn(a).func().sum()
3132
3133        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
3134        self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))
3135
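    # Without functionalization, the in-place aten.add_ should survive in the
    # exported graph; the test below counts add_.Tensor nodes to verify this.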
3136    def test_not_functionalize(self):
3137        class Foo(torch.nn.Module):
3138            def __init__(self) -> None:
3139                super().__init__()
3140                self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))
3141
3142            def forward(self, x):
3143                x.add_(2)
3144                return x.sum() + self.buffer1.sum()
3145
3146        example_inputs = (torch.ones(1, 2, 3),)
3147        gm, _ = torch._dynamo.export(
3148            Foo(),
3149            aten_graph=True,
3150            tracing_mode="symbolic",
3151        )(*example_inputs)
3152        count = 0
3153        for node in gm.graph.nodes:
3154            if node.target == torch.ops.aten.add_.Tensor:
3155                count += 1
3156        self.assertEqual(count, 1)
3157        test_inp = (torch.ones(1, 2, 3),)
3158        test_inp_v2 = (torch.ones(1, 2, 3),)
3159        self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))
3160
3161    def test_round_dynamic_shapes(self):
3162        def f(x):
3163            return x[: round(x.shape[0] / 2)]
3164
3165        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
3166
3167        self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))
3168
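    # The predicates below exercise cond() with both a SymNode-style predicate
    # (shape comparison) and a tensor predicate (x.all()).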
3169    def test_cond_supported_pred_types(self):
3170        def true_fn(x):
3171            return x.cos()
3172
3173        def false_fn(x):
3174            return x.sin()
3175
3176        def f_pred_traced_as_symnode_var(x):
3177            return cond(x.shape[0] > 2, true_fn, false_fn, [x])
3178
3179        def f_pred_traced_as_tensor_var(x):
3180            return cond(x.all(), true_fn, false_fn, [x])
3181
3182        def f_pred_complex_expression_traced_as_symnode_var(x):
3183            return cond(
3184                x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
3185                true_fn,
3186                false_fn,
3187                [x],
3188            )
3189
3190        example_inputs = (torch.rand(5, 8),)
3191        for f in [
3192            f_pred_traced_as_symnode_var,
3193            f_pred_traced_as_tensor_var,
3194            f_pred_complex_expression_traced_as_symnode_var,
3195        ]:
3196            gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_inputs)
3197            self.assertEqual(gm(*example_inputs), f(*example_inputs))
3198
3199    @unittest.expectedFailure  # TODO: Not sure why dynamo creates new inputs for self.a
3200    def test_sum_param(self):
3201        # Setting a new attribute inside forward()
3202        class Foo(torch.nn.Module):
3203            def __init__(self) -> None:
3204                super().__init__()
3205                self.a = torch.randn(3, 2)
3206
3207            def forward(self, x):
3208                self.b = 2
3209                return x.sum() + self.a.sum() + self.b
3210
3211        torch._dynamo.export(Foo())(torch.randn(3, 2))
3212
3213    def test_mixed_real_and_fake_inputs(self):
3214        class _TestPattern(torch.nn.Module):
3215            def __init__(self) -> None:
3216                super().__init__()
3217                self.conv = torch.nn.Conv2d(1, 1, 1)
3218                self.bn = torch.nn.BatchNorm2d(1)
3219
3220            def forward(self, input):
3221                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
3222                scale_factor = self.bn.weight / running_std
3223                weight_shape = [1] * len(self.conv.weight.shape)
3224                weight_shape[0] = -1
3225                bias_shape = [1] * len(self.conv.weight.shape)
3226                bias_shape[1] = -1
3227                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
3228                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
3229                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
3230                conv_orig = conv / scale_factor.reshape(bias_shape)
3231                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
3232                conv = self.bn(conv_orig)
3233                return conv
3234
3235        example_inputs = (torch.randn(1, 1, 3, 3),)
3236        torch._dynamo.export(
3237            _TestPattern(),
3238            aten_graph=True,
3239        )(*example_inputs)
3240
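    # `x.size(0) in y` is a data-dependent check on tensor contents; the config
    # patches below allow it to be captured during export.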
3241    @config.patch(
3242        capture_dynamic_output_shape_ops=True,
3243        capture_scalar_outputs=True,
3244        assume_static_by_default=False,
3245    )
3246    def test_sym_contains(self):
3247        def f(x, y):
3248            return x.size(0) in y
3249
3250        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3))
3251
3252        true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2))
3253        false_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5))
3254        self.assertEqual(gm(*true_inp), f(*true_inp))
3255        self.assertEqual(gm(*false_inp), f(*false_inp))
3256
3257    def test_cond_raise_user_error_on_missing_args(self):
3258        def true_fn(x):
3259            return x.cos()
3260
3261        def false_fn(x):
3262            return x.sin()
3263
3264        def f(x):
3265            return cond(x.shape[0] > 10, true_fn, false_fn)
3266
3267        example_inputs = (torch.rand(5),)
3268        with self.assertRaisesRegex(
3269            TypeError,
3270            r"cond\(\) missing 1 required positional argument: 'operands'",
3271        ):
3272            f(*example_inputs)
3273
3274    def test_cond_raise_user_error_on_unsupported_pred(self):
3275        def f_unsupported_pred(x):
3276            pred = torch.nn.Module()
3277            return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])
3278
3279        example_inputs = (torch.rand(5),)
3280        with self.assertRaisesRegex(
3281            RuntimeError,
3282            "Expected pred to be bool or tensor, but got Module()",
3283        ):
3284            f_unsupported_pred(*example_inputs)
3285
3286    def test_cond_raise_user_error_on_non_list_operands(self):
3287        def f_non_list_operands(x):
3288            return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)
3289
3290        example_inputs = (torch.rand(5),)
3291        with self.assertRaisesRegex(
3292            RuntimeError,
3293            r"Expect operands to be a tuple of possibly nested dict/list/tuple",
3294        ):
3295            f_non_list_operands(*example_inputs)
3296
3297    def test_cond_raise_user_error_on_non_tensor_operands(self):
3298        def f_non_tensor_operands(x):
3299            a: float = 3.14
3300            return cond(
3301                torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
3302            )
3303
3304        example_inputs = (torch.rand(5),)
3305        with self.assertRaisesRegex(
3306            RuntimeError,
3307            r"Expect operands to be a tuple of possibly nested dict/list/tuple",
3308        ):
3309            f_non_tensor_operands(*example_inputs)
3310
3311    def test_cond_raise_user_error_on_branch_args_mismatch(self):
3312        def true_fn(x, y):
3313            return x.sin()
3314
3315        def false_fn(x):
3316            return x.cos()
3317
3318        def f_branch_args_mismatch(x, y):
3319            return cond(torch.tensor([[[[True]]]]), true_fn, false_fn, [x, y])
3320
3321        example_inputs = (torch.rand(5), torch.rand(2))
3322        with self.assertRaisesRegex(
3323            torch._dynamo.exc.UncapturedHigherOrderOpError,
3324            "Cond doesn't work unless it is captured completely with torch.compile",
3325        ):
3326            torch._dynamo.export(
3327                f_branch_args_mismatch,
3328                aten_graph=True,
3329            )(
3330                *example_inputs,
3331            )
3332
3333    @config.patch(suppress_errors=True)
3334    def test_uncaptured_higher_order_op_error_not_suppressed(self):
3335        def true_fn(x, y):
3336            return x.sin()
3337
3338        def false_fn(x):
3339            return x.cos()
3340
3341        def f_branch_args_mismatch(x, y):
3342            return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])
3343
3344        example_inputs = (torch.rand(5), torch.rand(2))
3345        with self.assertRaisesRegex(
3346            torch._dynamo.exc.UncapturedHigherOrderOpError,
3347            "Cond doesn't work unless it is captured completely with torch.compile",
3348        ):
3349            torch._dynamo.export(
3350                f_branch_args_mismatch,
3351                aten_graph=True,
3352            )(
3353                *example_inputs,
3354            )
3355
3356    def test_cond_raise_user_error_on_branch_return_non_tensor(self):
3357        def f_branch_return_non_tensor(x):
3358            return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])
3359
3360        example_inputs = (torch.rand(5),)
3361        with self.assertRaisesRegex(
3362            torch._dynamo.exc.UncapturedHigherOrderOpError,
3363            "Cond doesn't work unless it is captured completely with torch.compile",
3364        ):
3365            torch._dynamo.export(
3366                f_branch_return_non_tensor,
3367                aten_graph=True,
3368            )(*example_inputs)
3369
3370    def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
3371        def f_branch_return_multiple_tensors(pred, x, y):
3372            return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
3373
3374        example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
3375        gm, _ = torch._dynamo.export(
3376            f_branch_return_multiple_tensors,
3377            aten_graph=True,
3378        )(*example_inputs)
3379        self.assertEqual(
3380            gm(*example_inputs), f_branch_return_multiple_tensors(*example_inputs)
3381        )
3382
3383    def test_multiple_outputs_op_with_evaluator(self):
3384        class TopKModel(torch.nn.Module):
3385            def forward(self, x):
3386                values, _ = torch.topk(x, 3)
3387                return torch.sum(values)
3388
3389        x = torch.arange(1.0, 6.0, requires_grad=True)
3390        torch._dynamo.export(TopKModel())(x)
3391
3392    def test_cond_raise_user_error_on_mismatch_return_length(self):
3393        def true_fn(x):
3394            return x
3395
3396        def false_fn(x):
3397            return (x, x)
3398
3399        def f_mismatch_return_length(x):
3400            return cond(torch.tensor(100), true_fn, false_fn, [x])
3401
3402        example_inputs = (torch.rand(5),)
3403        with self.assertRaisesRegex(
3404            RuntimeError, "Unmatched number of outputs from cond"
3405        ):
3406            torch._dynamo.export(
3407                f_mismatch_return_length,
3408                aten_graph=True,
3409            )(*example_inputs)
3410
3411    def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
3412        def true_fn(x):
3413            return torch.tensor([[3], [2]])
3414
3415        def false_fn(x):
3416            return torch.tensor([3.14])
3417
3418        def f_return_tensor_mismatch(x):
3419            return cond(x.shape[0] < 3, true_fn, false_fn, [x])
3420
3421        example_inputs = (torch.rand(5),)
3422        with self.assertRaisesRegex(
3423            torch._dynamo.exc.UncapturedHigherOrderOpError,
3424            "Cond doesn't work unless it is captured completely with torch.compile",
3425        ):
3426            torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)(
3427                *example_inputs,
3428            )
3429
3430    def test_byte_tensor_does_not_crash(self):
3431        # See https://github.com/pytorch/pytorch/issues/100455
3432        def func(text):
3433            tensor = torch.ByteTensor(list(bytes(text, "utf8")))
3434            return tensor + tensor
3435
3436        text = "".join(chr(a % 90 + 40) for a in range(111))
3437        opt_func = torch._dynamo.optimize("eager", dynamic=True)(func)
3438        for i in [99, 100]:
3439            input = text[:i]
3440            opt_func(input)
3441
3442    def test_export_defaults_ok(self):
3443        class DynamicSliceExportMod(torch.nn.Module):
3444            def forward(self, x):
3445                results = []
3446                for i in range(4):
3447                    results.append(x[: x.size(0) - i, i : x.size(2), i:3])
3448                return tuple(results)
3449
3450        gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)(
3451            torch.randn(5, 5, 5),
3452        )
3453
3454        self.assertExpectedInline(
3455            gm.code.strip(),
3456            """\
3457def forward(self, x):
3458    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
3459    arg0_1 = arg0
3460    sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
3461    slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
3462    sub = sym_size_int - 1
3463    slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub);  sub = None
3464    slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int);  slice_2 = None
3465    slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3);  slice_3 = None
3466    sub_1 = sym_size_int - 2
3467    slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1);  sub_1 = None
3468    slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int);  slice_5 = None
3469    slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3);  slice_6 = None
3470    sub_2 = sym_size_int - 3
3471    slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2);  arg0_1 = sub_2 = None
3472    slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int);  slice_8 = sym_size_int = None
3473    slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3);  slice_9 = None
3474    return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""",
3475        )
3476
3477    def test_capture_symbolic_tracing_simple_within_fake_mode(self):
3478        from torch._dynamo.output_graph import config
3479
3480        def f(x):
3481            y = torch.randn(3)
3482            return x + x * y
3483
3484        with fake_tensor.FakeTensorMode(
3485            shape_env=ShapeEnv(
3486                allow_scalar_outputs=config.capture_scalar_outputs,
3487                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
3488            ),
3489        ):
3490            x = torch.randn(3)
3491
3492            for aten_graph in [True, False]:
3493                gm, _ = torch._dynamo.export(f, aten_graph=aten_graph)(x)
3494                self.assertTrue(
3495                    isinstance(gm, torch.fx.GraphModule),
3496                    msg="test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_"
3497                    + str(aten_graph),
3498                )
3499
3500    def test_export_with_symbool_inputs(self):
3501        def f(pred: bool, x: torch.Tensor):
3502            if pred:
3503                return x.sin()
3504            else:
3505                return x.cos()
3506
3507        x = torch.randn([3, 4])
3508
3509        def test_symbool_guards(
3510            f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards
3511        ):
3512            shape_env = ShapeEnv()
3513            with fake_tensor.FakeTensorMode(
3514                shape_env=shape_env,
3515            ) as fake_mode:
3516                fake_x = fake_mode.from_tensor(
3517                    x,
3518                    symbolic_context=StatelessSymbolicContext(
3519                        dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())],
3520                    ),
3521                )
3522                for i, size in enumerate(size_tests):
3523                    pred = fake_x.size(0) == size
3524                    gm, guards = torch._dynamo.export(f)(pred, x)
3525                    actual = normalize_gm(gm.print_readable(print_output=False))
3526                    # TODO: This is naughty, EXPECTTEST_ACCEPT=1 doesn't work
3527                    self.assertExpectedInline(actual, exp_graph[i])
3528                    dynamo_shape_env_guards = [
3529                        guard
3530                        for guard in guards
3531                        if guard.guard_types is not None
3532                        and "SHAPE_ENV" in guard.guard_types
3533                    ]
3534                    self.assertEqual(len(dynamo_shape_env_guards), 1)
3535                    guard_code_on_predicate = [
3536                        code
3537                        for code in dynamo_shape_env_guards[0].code_list
3538                        if "L['pred']" in code
3539                    ]
3540                    self.assertEqual(guard_code_on_predicate, exp_guard_code[i])
3541                    outer_shape_env_guards = [
3542                        str(guard.expr) for guard in shape_env.guards
3543                    ]
3544                    self.assertEqual(outer_shape_env_guards, exp_shape_env_guards[i])
3545
3546        true_graph = """\
3547class GraphModule(torch.nn.Module):
3548    def forward(self, pred, x):
3549        arg1: "f32[s1, s2]";
3550
3551        arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
3552        l_x_ = arg1
3553
3554        sin: "f32[s1, s2]" = l_x_.sin();  l_x_ = None
3555        return pytree.tree_unflatten([sin], self._out_spec)
3556"""
3557        false_graph = """\
3558class GraphModule(torch.nn.Module):
3559    def forward(self, pred, x):
3560        arg1: "f32[s1, s2]";
3561
3562        arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
3563        l_x_ = arg1
3564
3565        cos: "f32[s1, s2]" = l_x_.cos();  l_x_ = None
3566        return pytree.tree_unflatten([cos], self._out_spec)
3567"""
3568        true_guard_code = [
3569            "cast_symbool_to_symint_guardless(L['pred']) == 1",
3570        ]
3571        false_guard_code = [
3572            "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
3573        ]
3574        test_symbool_guards(
3575            f,
3576            [3, 3, 4, 5],
3577            [true_graph, true_graph, false_graph, false_graph],
3578            [true_guard_code, true_guard_code, false_guard_code, false_guard_code],
3579            # Outter shape env should have no guards in it because we never specialize on the outter symbool.
3580            # Outer shape env should have no guards in it because we never specialize on the outer symbool.
3581        )
3582
3583    def test_invalid_input_global(self) -> None:
3584        global bulbous_bouffant
3585        bulbous_bouffant = torch.randn(3)
3586
3587        def f(y):
3588            return bulbous_bouffant + y
3589
3590        self.assertExpectedInlineMunged(
3591            UserError,
3592            lambda: torch._dynamo.export(f)(torch.randn(3)),
3593            """\
3594G['bulbous_bouffant'], accessed at:
3595  File "test_export.py", line N, in f
3596    return bulbous_bouffant + y
3597""",
3598        )
3599
3600    def test_invalid_input_global_multiple_access(self) -> None:
3601        global macademia
3602        macademia = torch.randn(3)
3603
3604        def g(y):
3605            global macademia
3606            y = macademia + y
3607            return y
3608
3609        def f(y):
3610            global macademia
3611            y = g(y)
3612            return macademia + y
3613
3614        # NB: This doesn't actually work (it only reports the first usage),
3615        # but I'm leaving the test here in case we fix it later
3616        self.assertExpectedInlineMunged(
3617            UserError,
3618            lambda: torch._dynamo.export(f)(torch.randn(3)),
3619            """\
3620G['macademia'], accessed at:
3621  File "test_export.py", line N, in f
3622    y = g(y)
3623  File "test_export.py", line N, in g
3624    y = macademia + y
3625""",
3626        )
3627
3628    def test_invalid_input_nonlocal(self) -> None:
3629        arglebargle = torch.randn(3)
3630
3631        def f(y):
3632            return arglebargle + y
3633
3634        self.assertExpectedInlineMunged(
3635            UserError,
3636            lambda: torch._dynamo.export(f)(torch.randn(3)),
3637            """L['arglebargle'], a closed over free variable""",
3638        )
3639
3640    def test_invalid_input_unused_nonlocal_ok(self) -> None:
3641        arglebargle = torch.randn(3)
3642
3643        def f(y):
3644            x = arglebargle
3645            return y
3646
3647        torch._dynamo.export(f)(torch.randn(3))
3648
3649    def test_symbolic_tracing_within_fake_mode_with_constraints(self):
3650        from torch._subclasses import fake_tensor
3651
3652        fake_mode = fake_tensor.FakeTensorMode()
3653
3654        class DynamicShapeSimpleModel(torch.nn.Module):
3655            def __init__(self) -> None:
3656                super().__init__()
3657
3658            def forward(self, a, b, c) -> torch.Tensor:
3659                d = (torch.matmul(a, b) + c) / 2
3660                d_s0 = d.shape[0]
3661                d_s1 = d.shape[1]
3662                d_s3 = d_s0 * d_s1
3663                e = d.view(d_s3)
3664                return torch.cat([e, e])
3665
3666        with fake_mode:
3667            model = DynamicShapeSimpleModel()
3668            inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
3669            dim = torch.export.Dim("dim")
3670            dynamic_shapes = ({0: dim}, None, {0: dim})
3671            for aten_graph in [True, False]:
3672                gm = torch._dynamo.export(
3673                    model,
3674                    dynamic_shapes=dynamic_shapes,
3675                    aten_graph=aten_graph,
3676                )(*inputs).graph_module
3677
3678        # Since there are no parameters, the graph exported under fake mode can be run directly on real inputs
3679        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
3680        self.assertEqual(model(*inputs), gm(*inputs))
3681
3682    def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self):
3683        from torch._subclasses import fake_tensor
3684
3685        fake_mode = fake_tensor.FakeTensorMode()
3686
3687        # TODO: Seems to choke if you don't make a fresh model and
3688        # just try to export Linear directly...
3689        class Model(torch.nn.Module):
3690            def __init__(self) -> None:
3691                super().__init__()
3692                self.linear = torch.nn.Linear(2, 2)
3693
3694            def forward(self, x):
3695                out = self.linear(x)
3696                return out
3697
3698        with fake_mode:
3699            model = Model()
3700            inputs = (torch.randn(10, 2, 2),)
3701            dynamic_shapes = ({0: torch.export.Dim("dim")},)
3702            for aten_graph in [True, False]:
3703                gm = torch._dynamo.export(
3704                    model,
3705                    dynamic_shapes=dynamic_shapes,
3706                    aten_graph=aten_graph,
3707                )(*inputs).graph_module
3708
3709    def test_capture_symbolic_tracing_within_fake_mode(self):
3710        from torch._dynamo.output_graph import config
3711        from torch._subclasses import fake_tensor
3712        from torch.fx.experimental.symbolic_shapes import ShapeEnv
3713
3714        class Model(torch.nn.Module):
3715            def __init__(self) -> None:
3716                super().__init__()
3717                self.linear = torch.nn.Linear(2, 2)
3718                self.linear2 = torch.nn.Linear(2, 2)
3719
3720            def forward(self, x):
3721                out = self.linear(x)
3722                out = self.linear2(out)
3723                return out
3724
3725        # User-instantiated FakeTensorMode
3726        fake_mode = fake_tensor.FakeTensorMode(
3727            allow_non_fake_inputs=False,
3728            allow_fallback_kernels=True,
3729            shape_env=ShapeEnv(
3730                allow_scalar_outputs=config.capture_scalar_outputs,
3731                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
3732            ),
3733        )
3734        # Create the input and model under fake mode before exporting them
3735        with fake_mode:
3736            x = torch.rand(5, 2, 2)
3737            model = Model()
3738
3739            # Export the model with fake inputs and parameters
3740            for aten_graph in [True, False]:
3741                graph_module, _ = torch._dynamo.export(model, aten_graph=aten_graph)(x)
3742                self.assertTrue(
3743                    isinstance(graph_module, torch.fx.GraphModule),
3744                    msg="test_capture_symbolic_tracing_within_fake_mode_aten_graph_"
3745                    + str(aten_graph),
3746                )
3747
3748    def test_cond_op_param_buffer_lifted(self):
3749        class A(torch.nn.Module):
3750            def __init__(self) -> None:
3751                super().__init__()
3752                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))
3753
3754            def forward(self):
3755                return self.buffer1.sum()
3756
3757        class B(torch.nn.Module):
3758            def __init__(self) -> None:
3759                super().__init__()
3760                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))
3761
3762            def forward(self):
3763                return self.buffer2.sum()
3764
3765        class M(torch.nn.Module):
3766            def __init__(self) -> None:
3767                super().__init__()
3768                self.a = A()
3769                self.b = B()
3770
3771            def forward(self, x):
3772                def true_fn(x):
3773                    return x.cos() + self.a()
3774
3775                def false_fn(x):
3776                    return x.sin() + self.b()
3777
3778                return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)
3779
3780        gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4))
3781        self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4)))
3782        self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4)))
3783
3784    def test_nested_cond_op_param_buffer_lifted(self):
3785        class A(torch.nn.Module):
3786            def __init__(self) -> None:
3787                super().__init__()
3788                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))
3789
3790            def forward(self):
3791                return self.buffer1.sum()
3792
3793        class B(torch.nn.Module):
3794            def __init__(self) -> None:
3795                super().__init__()
3796                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))
3797
3798            def forward(self):
3799                return self.buffer2.sum()
3800
3801        class M(torch.nn.Module):
3802            def __init__(self) -> None:
3803                super().__init__()
3804                self.a = A()
3805                self.b = B()
3806
3807            def forward(self, x):
3808                def true_true_fn(x):
3809                    return x.cos() + self.a()
3810
3811                def true_false_fn(x):
3812                    return x.cos() + self.a() + 1
3813
3814                def true_fn(x):
3815                    return cond(x.shape[0] > 5, true_true_fn, true_false_fn, [x])
3816
3817                def false_fn(x):
3818                    return x.sin() + self.b()
3819
3820                return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)
3821
3822        gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4))
3823        self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4)))
3824        self.assertEqual(gm(torch.ones(5, 4)), M()(torch.ones(5, 4)))
3825        self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4)))
3826
3827    def test_map_cond_param_buffer_lifted(self):
3828        from functorch.experimental.control_flow import cond, map
3829
3830        class A(torch.nn.Module):
3831            def __init__(self) -> None:
3832                super().__init__()
3833                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))
3834
3835            def forward(self):
3836                return self.buffer1.sum()
3837
3838        class B(torch.nn.Module):
3839            def __init__(self) -> None:
3840                super().__init__()
3841                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))
3842
3843            def forward(self):
3844                return self.buffer2.sum()
3845
3846        class Module(torch.nn.Module):
3847            def __init__(self) -> None:
3848                super().__init__()
3849                self.a = A()
3850                self.b = B()
3851
3852            def inner(self, x, pred):
3853                def true_fn(x):
3854                    return x + x + self.a()
3855
3856                def false_fn(x):
3857                    return x * x + self.b()
3858
3859                return cond(pred, true_fn, false_fn, [x])
3860
3861            def forward(self, pred, xs):
3862                def body(x, pred):
3863                    return self.inner(x, pred) + self.b()
3864
3865                return map(body, xs, pred)
3866
3867        mod = Module()
3868        x = torch.randn(3, 2, 1)
3869        pred_x = torch.tensor(True)
3870
3871        y = torch.randn(4, 3, 2)
3872        pred_y = torch.tensor(False)
3873        real_result = mod(pred_y, y)
3874
3875        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
3876        self.assertEqual(real_result, out_graph(pred_y, y))
3877
3878    def test_cond_free_variables_overlapping(self):
3879        from functorch.experimental.control_flow import cond
3880
3881        class Module(torch.nn.Module):
3882            def __init__(self) -> None:
3883                super().__init__()
3884
3885            def forward(self, pred, x):
3886                a = torch.ones(6, 4)
3887                b = torch.ones(6, 4)
3888                c = torch.ones(6, 4)
3889                d = torch.ones(6, 4)
3890
3891                def true_fn(x):
3892                    return x + x + a.cos() + b.cos() + d.cos()
3893
3894                def false_fn(x):
3895                    return x * x + a.sin() + b.sin() + c.sin()
3896
3897                return cond(pred, true_fn, false_fn, [x])
3898
3899        mod = Module()
3900        x = torch.ones(6, 4)
3901        pred_x = torch.tensor(True)
3902
3903        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
3904        self.assertExpectedInline(
3905            out_graph.code.strip(),
3906            """\
3907def forward(self, pred, x):
3908    arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
3909    l_pred_ = arg0
3910    l_x_ = arg1
3911    a = torch.ones(6, 4)
3912    b = torch.ones(6, 4)
3913    c = torch.ones(6, 4)
3914    d = torch.ones(6, 4)
3915    cond_true_0 = self.cond_true_0
3916    cond_false_0 = self.cond_false_0
3917    cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [a, b, l_x_, d, c]);  l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None
3918    getitem = cond[0];  cond = None
3919    return pytree.tree_unflatten([getitem], self._out_spec)""",  # noqa: B950,E122
3920        )
3921
3922        self.assertExpectedInline(
3923            out_graph.cond_true_0.code.strip(),
3924            """\
3925def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
3926    a_1 = a
3927    b_1 = b
3928    l_x__1 = l_x_
3929    add = l_x__1 + l_x__1;  l_x__1 = None
3930    cos = a_1.cos();  a_1 = None
3931    add_1 = add + cos;  add = cos = None
3932    cos_1 = b_1.cos();  b_1 = None
3933    add_2 = add_1 + cos_1;  add_1 = cos_1 = None
3934    cos_2 = d_true_branch.cos();  d_true_branch = None
3935    add_3 = add_2 + cos_2;  add_2 = cos_2 = None
3936    return (add_3,)""",
3937        )
3938
3939        self.assertExpectedInline(
3940            out_graph.cond_false_0.code.strip(),
3941            """\
3942def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
3943    a_1 = a
3944    b_1 = b
3945    l_x__1 = l_x_
3946    mul = l_x__1 * l_x__1;  l_x__1 = None
3947    sin = a_1.sin();  a_1 = None
3948    add = mul + sin;  mul = sin = None
3949    sin_1 = b_1.sin();  b_1 = None
3950    add_1 = add + sin_1;  add = sin_1 = None
3951    sin_2 = c_false_branch.sin();  c_false_branch = None
3952    add_2 = add_1 + sin_2;  add_1 = sin_2 = None
3953    return (add_2,)""",
3954        )
3955
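    # Retracing: exporting an already-exported GraphModule should yield an
    # equivalent graph, compared by running both on fresh inputs.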
3956    @unittest.skipIf(
3957        common_utils.TEST_WITH_ASAN,
3958        "Times out with ASAN, see https://github.com/pytorch/pytorch/issues/110416",
3959    )
3960    def test_retracibility(self):
3961        class MyLinear(torch.nn.Module):
3962            def __init__(self) -> None:
3963                super().__init__()
3964                self.weight = torch.randn(20, 98)
3965                self.bias = torch.randn(20)
3966
3967            def forward(self, x):
3968                return torch.nn.functional.linear(x, self.weight, self.bias)
3969
3970        class Foo(torch.nn.Module):
3971            def __init__(self) -> None:
3972                super().__init__()
3973                self.conv = torch.nn.Conv2d(16, 33, 3)
3974                self.linear = MyLinear()
3975
3976            def forward(self, x):
3977                a, b = x
3978                a_conv = self.conv(a)
3979                a_linear = self.linear(a_conv)
3980                b_conv = self.conv(b)
3981                b_linear = self.linear(b_conv)
3982                return (
3983                    a_linear.cos() + b_linear.sin(),
3984                    a_linear.sin() + b_linear.cos(),
3985                )
3986
3987        inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))
3988
3989        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
3990        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)
3991
3992        inp_test = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))
3993
3994        self.assertTrue(torch.allclose(gm(inp_test)[0], gm2(inp_test)[0]))
3995        self.assertTrue(torch.allclose(gm(inp_test)[1], gm2(inp_test)[1]))
3996
3997    def test_retracibility_dict_container_inp_out(self):
3998        class MyLinear(torch.nn.Module):
3999            def __init__(self) -> None:
4000                super().__init__()
4001                self.weight = torch.randn(20, 98)
4002                self.bias = torch.randn(20)
4003
4004            def forward(self, x):
4005                return torch.nn.functional.linear(x, self.weight, self.bias)
4006
4007        class Foo(torch.nn.Module):
4008            def __init__(self) -> None:
4009                super().__init__()
4010                self.conv = torch.nn.Conv2d(16, 33, 3)
4011                self.linear = MyLinear()
4012
4013            def forward(self, x):
4014                a1, a2 = x["a"]
4015                b = x["b"]
4016                a1_conv = self.conv(a1)
4017                a1_linear = self.linear(a1_conv)
4018                a2_conv = self.conv(a2)
4019                a2_linear = self.linear(a2_conv)
4020                b_conv = self.conv(b)
4021                b_linear = self.linear(b_conv)
4022                return {
4023                    "a": [
4024                        a1_linear.cos() + b_linear.sin(),
4025                        a1_linear.cos() + b_linear.sin(),
4026                    ],
4027                    "b": a2_linear.sin() + b_linear.cos(),
4028                }
4029
4030        inp_container = {
4031            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
4032            "b": torch.randn(20, 16, 50, 100),
4033        }
4034
4035        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
4036        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)
4037
4038        inp_test = {
4039            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
4040            "b": torch.randn(20, 16, 50, 100),
4041        }
4042
4043        self.assertTrue(torch.allclose(gm(inp_test)["a"][0], gm2(inp_test)["a"][0]))
4044        self.assertTrue(torch.allclose(gm(inp_test)["a"][1], gm2(inp_test)["a"][1]))
4045        self.assertTrue(torch.allclose(gm(inp_test)["b"], gm2(inp_test)["b"]))
4046
4047    def test_retracibility_nested_list_out(self):
4048        class MyLinear(torch.nn.Module):
4049            def __init__(self) -> None:
4050                super().__init__()
4051                self.weight = torch.randn(20, 98)
4052                self.bias = torch.randn(20)
4053
4054            def forward(self, x):
4055                return torch.nn.functional.linear(x, self.weight, self.bias)
4056
4057        class Foo(torch.nn.Module):
4058            def __init__(self) -> None:
4059                super().__init__()
4060                self.conv = torch.nn.Conv2d(16, 33, 3)
4061                self.linear = MyLinear()
4062
4063            def forward(self, x):
4064                a1, a2 = x["a"]
4065                b = x["b"]
4066                a1_conv = self.conv(a1)
4067                a1_linear = self.linear(a1_conv)
4068                a2_conv = self.conv(a2)
4069                a2_linear = self.linear(a2_conv)
4070                b_conv = self.conv(b)
4071                b_linear = self.linear(b_conv)
4072                return [
4073                    [
4074                        a1_linear.cos() + b_linear.sin(),
4075                        a1_linear.cos() + b_linear.sin(),
4076                    ],
4077                    [
4078                        a2_linear.sin() + b_linear.cos(),
4079                        a2_linear.sin() + b_linear.cos(),
4080                    ],
4081                ]
4082
4083        inp_container = {
4084            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
4085            "b": torch.randn(20, 16, 50, 100),
4086        }
4087
4088        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
4089        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)
4090
4091        inp_test = {
4092            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
4093            "b": torch.randn(20, 16, 50, 100),
4094        }
4095
4096        self.assertTrue(torch.allclose(gm(inp_test)[0][0], gm2(inp_test)[0][0]))
4097        self.assertTrue(torch.allclose(gm(inp_test)[0][1], gm2(inp_test)[0][1]))
4098        self.assertTrue(torch.allclose(gm(inp_test)[1][0], gm2(inp_test)[1][0]))
4099        self.assertTrue(torch.allclose(gm(inp_test)[1][1], gm2(inp_test)[1][1]))
4100
4101    def test_fx_pytree(self):
4102        def foo(args):
4103            flat_args, spec = torch.utils._pytree.tree_flatten(args)
4104            flat_args_fx = torch.fx._pytree.tree_flatten_spec(args, spec)
4105            return flat_args_fx[0] + flat_args[0]
4106
4107        inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))
4108
4109        gm, _ = torch._dynamo.export(foo, inp_container, aten_graph=True)
4110
4111        self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container)))
4112
4113    @config.patch(suppress_errors=True)
4114    @config.patch(verbose=True)
4115    def test_export_with_map_zero_sized_tensor_suppress_errors(self):
4116        from functorch.experimental.control_flow import map
4117
4118        class Module(torch.nn.Module):
4119            def forward(self, xs):
4120                def body(x):
4121                    return x + 1
4122
4123                return map(body, xs)
4124
4125        mod = Module()
4126        xs = torch.randn(0, 2)
4127        with self.assertRaises(
4128            torch._dynamo.exc.Unsupported,
4129        ):
4130            out_graph, _ = torch._dynamo.export(mod, xs)
4131
4132    def test_param_buffer_safe_from_mutation_simple(self):
4133        class Module(torch.nn.Module):
4134            def __init__(self) -> None:
4135                super().__init__()
4136                self.buffer1 = torch.nn.Buffer(torch.zeros(5, 5))
4137
4138            def forward(self, x):
4139                self.buffer1.add_(1)
4140                return x + self.buffer1
4141
4142        gm, _ = torch._dynamo.export(Module(), torch.ones(5, 5), aten_graph=False)
4143        buffers = list(gm.named_buffers())
4144        self.assertEqual(len(buffers), 1)
4145
4146        name, buffer = buffers[0]
4147        self.assertEqual(name, "L__self___buffer1")
4148
4149        self.assertTrue(torch.allclose(buffer, torch.zeros(5)))
4150
4151    def test_param_buffer_safe_from_mutation_recurse(self):
4152        class Child(torch.nn.Module):
4153            def __init__(self) -> None:
4154                super().__init__()
4155                self.buffer2 = torch.nn.Buffer(torch.zeros(5))
4156
4157            def forward(self, x):
4158                return x.sum() + self.buffer2.sum()
4159
4160        class Module(torch.nn.Module):
4161            def __init__(self) -> None:
4162                super().__init__()
4163                self.buffer1 = torch.nn.Buffer(torch.zeros(5))
4164                self.child = Child()
4165
4166            def forward(self, x):
4167                self.buffer1.add_(1)
4168                self.child.buffer2.add_(2)
4169                return x.sum() + self.buffer1.sum() + self.child(x)
4170
4171        gm, _ = torch._dynamo.export(Module(), torch.ones(5), aten_graph=False)
4172        for name, buffer in gm.named_buffers():
4173            self.assertTrue(torch.allclose(buffer, torch.zeros(5)))
4174
4175    def test_predispatch_with_higher_order(self):
4176        def f(x):
4177            return cond(x.shape[0] > 4, lambda x: x + 5, lambda x: x - 3, [x])
4178
4179        gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)(
4180            torch.randn(4, 4)
4181        )
4182        inp1 = torch.randn(4, 4)
4183        inp2 = torch.randn(6, 4)
4184        self.assertTrue(torch.allclose(f(inp1), gm(inp1)))
4185        self.assertTrue(torch.allclose(f(inp2), gm(inp2)))
4186
4187    def test_predispatch_with_higher_order_nested(self):
4188        def f(x):
4189            def true_fn(x):
4190                return cond(x.shape[0] > 6, lambda x: x + 10, lambda x: x - 10, [x])
4191
4192            return cond(x.shape[0] > 4, true_fn, lambda x: x - 3, [x])
4193
4194        gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)(
4195            torch.randn(4, 4)
4196        )
4197        inp1 = torch.randn(4, 4)
4198        inp2 = torch.randn(6, 4)
4199        inp3 = torch.randn(8, 4)
4200        self.assertTrue(torch.allclose(f(inp1), gm(inp1)))
4201        self.assertTrue(torch.allclose(f(inp2), gm(inp2)))
4202        self.assertTrue(torch.allclose(f(inp3), gm(inp3)))
4203
4204    def test_predispatch_with_for_out_dtype(self):
4205        class M(torch.nn.Module):
4206            def __init__(self, weight):
4207                super().__init__()
4208                self.weight = weight
4209
4210            def forward(self, x):
4211                return out_dtype(torch.ops.aten.mm.default, torch.int32, x, self.weight)
4212
4213        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
4214        m = M(weight)
4215        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
4216        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)
4217
4218        self.assertTrue(torch.allclose(m(x), gm(x)))
4219
4220    def test_predispatch_with_for_out_dtype_nested(self):
4221        class M(torch.nn.Module):
4222            def __init__(self, weight):
4223                super().__init__()
4224                self.weight = weight
4225
4226            def true_fn(self, x):
4227                return out_dtype(
4228                    torch.ops.aten.mm.default, torch.int32, x, self.weight
4229                ).sum()
4230
4231            def false_fn(self, x):
4232                return out_dtype(
4233                    torch.ops.aten.mul.Tensor, torch.int32, x, self.weight
4234                ).sum()
4235
4236            def forward(self, x):
4237                return cond(x.sum() != 0, self.true_fn, self.false_fn, [x])
4238
4239        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
4240        m = M(weight)
4241        x = torch.ones((5, 5), dtype=torch.int8)
4242        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)
4243
4244        self.assertTrue(torch.allclose(m(x), gm(x)))
4245        y = torch.zeros((5, 5), dtype=torch.int8)
4246        self.assertTrue(torch.allclose(m(y), gm(y)))
4247
4248        self.assertExpectedInline(
4249            gm.true_graph_0.code.strip(),
4250            """\
4251def forward(self, arg0_1, arg1_1):
4252    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1);  arg1_1 = arg0_1 = None
4253    sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
4254    return (sum_1,)""",
4255        )
4256
4257        self.assertExpectedInline(
4258            gm.false_graph_0.code.strip(),
4259            """\
4260def forward(self, arg0_1, arg1_1):
4261    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1);  arg1_1 = arg0_1 = None
4262    sum_1 = torch.ops.aten.sum.default(out_dtype);  out_dtype = None
4263    return (sum_1,)""",
4264        )
4265
4266    def test_export_nn_module_stack_patched_module(self):
4267        def forward(self, x, y):
4268            return x * y
4269
4270        class Toplevel(torch.nn.Module):
4271            def __init__(self, m):
4272                super().__init__()
4273                self.m = m
4274
4275            def forward(self, x, y):
4276                return self.m(x, y)
4277
4278        class M(torch.nn.Module):
4279            def forward(self, x, y):
4280                return x + y
4281
4282        t = Toplevel(M())
4283        t.m.forward = forward.__get__(t.m, M)
4284        x, y = torch.rand(3), torch.rand(3)
4285        gm, _ = torch._dynamo.export(t, x, y)
4286
4287        self.assertTrue(torch.allclose(forward(None, x, y), gm(x, y)))
4288        for node in gm.graph.nodes:
4289            if node.op == "call_function":
4290                self.assertIn("nn_module_stack", node.meta)
4291
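    # The next few tests check that fx node metadata (e.g. nn_module_stack and
    # stack_trace) is preserved across re-export, graph editing, and recompiles.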
4292    def test_preserve_fx_node_metadata(self):
4293        class Module1(torch.nn.Module):
4294            def forward(self, x):
4295                return torch.sin(x)
4296
4297        class Module2(torch.nn.Module):
4298            def __init__(self) -> None:
4299                super().__init__()
4300                self.mod1 = Module1()
4301
4302            def forward(self, x):
4303                x = torch.cos(x)
4304                x = self.mod1(x)
4305                x = torch.relu(x)
4306                return x
4307
4308        def fn(x):
4309            return torch.abs(x)
4310
4311        mod = Module2()
4312        inp = torch.randn(3, 3)
4313
4314        gm, _ = torch._dynamo.export(mod)(inp)
4315
4316        # replace relu with fn
4317        gm_edit = copy.deepcopy(gm)
4318        for nd in gm_edit.graph.nodes:
4319            if nd.target == torch.relu:
4320                nd.target = fn
4321                nd.meta.clear()
4322                break
4323        gm_edit.recompile()
4324
4325        gm2, _ = torch._dynamo.export(gm_edit)(inp)
4326
4327        self.assertExpectedInline(
4328            gm.code.strip(),
4329            """\
4330def forward(self, x):
4331    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4332    l_x_ = arg0
4333    x = torch.cos(l_x_);  l_x_ = None
4334    x_1 = torch.sin(x);  x = None
4335    x_2 = torch.relu(x_1);  x_1 = None
4336    return pytree.tree_unflatten([x_2], self._out_spec)""",
4337        )
4338
4339        def _contains_op(gm, target):
4340            for nd in gm.graph.nodes:
4341                if nd.target == target:
4342                    return True
4343            return False
4344
4345        self.assertTrue(_contains_op(gm_edit, torch.cos))
4346        self.assertTrue(_contains_op(gm_edit, torch.sin))
4347        self.assertTrue(not _contains_op(gm_edit, torch.relu))
4348
4349        self.assertExpectedInline(
4350            gm2.code.strip(),
4351            """\
4352def forward(self, x):
4353    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4354    l_x_ = arg0
4355    x = torch.cos(l_x_);  l_x_ = None
4356    x_1 = torch.sin(x);  x = None
4357    x_2 = torch.abs(x_1);  x_1 = None
4358    return pytree.tree_unflatten([x_2], self._out_spec)""",
4359        )
4360
4361        # check that other node metadata (nn_module_stack, stack_trace) is preserved
4362        for op in (torch.sin, torch.cos):
4363            nd1 = next(filter(lambda nd: nd.target == op, gm.graph.nodes))
4364            nd2 = next(filter(lambda nd: nd.target == op, gm2.graph.nodes))
4365            self.assertTrue(
4366                ("nn_module_stack" in nd1.meta) == ("nn_module_stack" in nd2.meta)
4367            )
4368            if "nn_module_stack" in nd1.meta:
4369                self.assertEqual(
4370                    nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
4371                )
4372            self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])
4373
4374    def test_preserve_fx_node_metadata_recompile(self):
4375        def fn(x):
4376            return torch.sin(x)
4377
4378        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
4379        do_export = torch._dynamo.export(gm)
4380        torch._dynamo.optimize("eager")(fn)(torch.randn(3, 3))
4381        gm1, _ = do_export(torch.randn(3, 3))
4382        gm2, _ = do_export(torch.randn(5, 3))
4383
4384        self.assertExpectedInline(
4385            gm1.code.strip(),
4386            """\
4387def forward(self, x):
4388    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4389    l_x_ = arg0
4390    sin = torch.sin(l_x_);  l_x_ = None
4391    return pytree.tree_unflatten([sin], self._out_spec)""",
4392        )
4393        self.assertExpectedInline(
4394            gm2.code.strip(),
4395            """\
4396def forward(self, x):
4397    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4398    l_x_ = arg0
4399    sin = torch.sin(l_x_);  l_x_ = None
4400    return pytree.tree_unflatten([sin], self._out_spec)""",
4401        )
4402
4403    def test_preserve_fx_node_metadata_inline(self):
4404        def f1(x):
4405            return torch.sin(x)
4406
4407        gm, _ = torch._dynamo.export(f1)(torch.randn(3, 3))
4408
4409        def f2(x):
4410            x = torch.cos(x)
4411            return gm(x)
4412
4413        gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3))
4414
4415        self.assertExpectedInline(
4416            gm2.code.strip(),
4417            """\
4418def forward(self, x):
4419    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4420    l_x_ = arg0
4421    x = torch.cos(l_x_);  l_x_ = None
4422    sin = torch.sin(x);  x = None
4423    return pytree.tree_unflatten([sin], self._out_spec)""",
4424        )
4425
4426    def test_preserve_fx_node_metadata_graph_break(self):
4427        def fn(x):
4428            x = torch.sin(x)
4429            x = torch.abs(x)
4430            return torch.cos(x)
4431
4432        def bad_fn(x):
4433            torch._dynamo.graph_break()
4434            return x
4435
4436        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
4437
4438        # replace abs with graph break
4439        gm_edit = copy.deepcopy(gm)
4440        for nd in gm_edit.graph.nodes:
4441            if nd.target == torch.abs:
4442                nd.target = bad_fn
4443                nd.meta.clear()
4444                break
4445        gm_edit.recompile()
4446
4447        expected = [
4448            """x = torch.sin(l_x_)""",
4449            """cos = torch.cos(l_stack0_)""",
4450        ]
4451
4452        def test_backend(gm: torch.fx.GraphModule, example_inputs):
4453            self.assertTrue(expected)
4454            # Strip example_value metadata so the printed graph matches whether shapes are dynamic or not
4455            for nd in gm.graph.nodes:
4456                if "example_value" in nd.meta:
4457                    del nd.meta["example_value"]
4458            self.assertIn(expected[0], gm.print_readable(print_output=False))
4459            expected.pop(0)
4460            return gm.forward
4461
4462        torch._dynamo.reset()
4463        opt_gm_edit = torch.compile(gm_edit, backend=test_backend)
4464        opt_gm_edit(torch.randn(3, 3))
4465
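    # inference_mode entry/exit should show up as explicit calls in the exported
    # graph, and the outputs should not require grad.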
4466    def test_torch_inference_mode_ctx(self):
4467        @torch.inference_mode()
4468        def fn(x):
4469            return x + 1
4470
4471        gm, _ = torch._dynamo.export(fn, torch.rand(2, 2))
4472
4473        inp = torch.randn(2, 2)
4474        out = gm(inp)
4475        self.assertExpectedInline(
4476            gm.code.strip(),
4477            """\
4478def forward(self, x):
4479    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4480    l_args_0_ = arg0
4481    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
4482    add = l_args_0_ + 1;  l_args_0_ = None
4483    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
4484    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
4485        )
4486        self.assertEqual(out.requires_grad, False)
4487        with self.assertRaisesRegex(
4488            RuntimeError,
4489            "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.",
4490        ):
4491            out.requires_grad = True
4492
4493        @torch.inference_mode(False)
4494        def fn_no_inference(x):
4495            return x + 1
4496
4497        gm_no_inference, _ = torch._dynamo.export(fn_no_inference, torch.rand(2, 2))
4498        self.assertExpectedInline(
4499            gm_no_inference.code.strip(),
4500            """\
4501def forward(self, x):
4502    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4503    l_args_0_ = arg0
4504    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False)
4505    add = l_args_0_ + 1;  l_args_0_ = None
4506    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
4507    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
4508        )
4509
4510        inp = torch.randn(2, 2)
4511        out = gm_no_inference(inp)
4512        self.assertEqual(out.requires_grad, False)
4513        out.requires_grad = True
4514
4515        def fn(x):
4516            with torch.inference_mode():
4517                return x + 1
4518
4519        gm, _ = torch._dynamo.export(fn)(torch.rand(2, 2))
4520        self.assertExpectedInline(
4521            gm.code.strip(),
4522            """\
4523def forward(self, x):
4524    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
4525    l_x_ = arg0
4526    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
4527    add = l_x_ + 1;  l_x_ = None
4528    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
4529    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
4530        )
4531        inp = torch.randn(2, 2, requires_grad=True)
4532        out = gm(inp)
4533        self.assertEqual(out.requires_grad, False)
4534
4535    def test_export_masking_with_no_grad(self):
4536        def fn(x, b, y):
4537            x = x.clone()
4538            x[b] = y
4539            return x
4540
4541        def fn_no_grad(x, b, y):
4542            with torch.no_grad():
4543                return fn(x, b, y)
4544
4545        def fn_inference_mode(x, b, y):
4546            with torch.inference_mode():
4547                return fn(x, b, y)
4548
4549        x = torch.randn(4, requires_grad=True)
4550        b = torch.tensor([True, False, True, False])
4551        y = torch.randn(2, requires_grad=True)
4552
4553        gm, _ = torch._dynamo.export(fn_no_grad)(x, b, y)
4554        self.assertExpectedInline(
4555            gm.code.strip(),
4556            """\
4557def forward(self, x, b, y):
4558    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
4559    l_x_ = arg0
4560    l_b_ = arg1
4561    l_y_ = arg2
4562    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
4563    x = l_x_.clone();  l_x_ = None
4564    x[l_b_] = l_y_;  setitem = x;  l_b_ = l_y_ = setitem = None
4565    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
4566    return pytree.tree_unflatten([x], self._out_spec)""",
4567        )
4568
4569        gm, _ = torch._dynamo.export(fn_inference_mode)(x, b, y)
4570        self.assertExpectedInline(
4571            gm.code.strip(),
4572            """\
4573def forward(self, x, b, y):
4574    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
4575    l_x_ = arg0
4576    l_b_ = arg1
4577    l_y_ = arg2
4578    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
4579    x = l_x_.clone();  l_x_ = None
4580    x[l_b_] = l_y_;  setitem = x;  l_b_ = l_y_ = setitem = None
4581    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode);  _enter_inference_mode = _exit_inference_mode = None
4582    return pytree.tree_unflatten([x], self._out_spec)""",  # NOQA: B950
4583        )
4584
4585        with self.assertRaisesRegex(
4586            torch._dynamo.exc.Unsupported, "boolean masking setitem backwards"
4587        ):
4588            gm, _ = torch._dynamo.export(fn)(x, b, y)
4589
4590    def test_dynamo_list_index(self):
4591        def fn(x, in_list):
4592            return x + in_list.index(2)
4593
4594        inputs = (torch.ones(2, 2), [1, 2])
4595        graph, _ = torch._dynamo.export(fn)(*inputs)
4596        out = graph(*inputs)
4597        self.assertEqual(out, torch.ones(2, 2) + 1)
4598
4599
4600common_utils.instantiate_parametrized_tests(ExportTests)
4601
4602if __name__ == "__main__":
4603    from torch._dynamo.test_case import run_tests
4604
4605    run_tests()
4606