# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Any, List

import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._export.utils import (
    get_buffer,
    get_param,
    is_buffer,
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.export.unflatten import _disable_interpreter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import (
    LeafSpec,
    tree_flatten,
    tree_unflatten,
    TreeSpec,
    treespec_dumps,
    treespec_loads,
)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestUnflatten(TestCase):
    def compare_outputs(self, eager, unflattened, args):
        orig_output = eager(*args)
        unflattened_output = unflattened(*args)
        self.assertTrue(torch.allclose(orig_output, unflattened_output))

    def test_unflatten_nested(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

    def test_unflatten_buffer_mutation(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                self.child2buffer.add_(x)
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.foo(x)
                return x * self.rootparam

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        # The eager and unflattened buffers should match both before and after a run (the forward pass mutates them in place)
        eager_buffer = eager_module.foo.child2buffer
        unflattened_buffer = unflattened_module.foo.child2buffer
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

        inputs = (torch.rand(2, 3),)
        eager_module(*inputs)
        unflattened_module(*inputs)
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

    def test_unflatten_nested_access(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x + self.foo.child2buffer
                x = self.foo(x)
                return x

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(eager_module, unflattened_module, inputs)

    def test_unflatten_shared_submodule(self):
        class Shared(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layernorm = torch.nn.LayerNorm(10)
                self.sub_net = torch.nn.Sequential(
                    layernorm,
                    torch.nn.ReLU(),
                    layernorm,
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.sub_net(x)

        eager_module = Shared()
        inps = (torch.rand(10),)
        export_module = export(eager_module, inps, {})
        unflattened_module = unflatten(export_module)
        self.compare_outputs(eager_module, unflattened_module, inps)
        self.assertTrue(hasattr(unflattened_module, "sub_net"))
        for i in range(len(eager_module.sub_net)):
            self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
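        # The eagerly reused LayerNorm at indices 0 and 2 should remain a single shared instance after unflattening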
        self.assertEqual(
            id(getattr(unflattened_module.sub_net, "0")),
            id(getattr(unflattened_module.sub_net, "2")),
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
    def test_unflatten_preserve_signature(self):
        class NestedChild(torch.nn.Module):
            def forward(self, zx, y):
                return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()

            def forward(self, x, y):
                z = torch.ones_like(x)
                xw = self.nested((z, x), y={"key": y})
                return xw["w"] + z - xw["x"]

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)
        for strict in [True, False]:
            export_module = export(
                orig_eager,
                inps,
                {},
                preserve_module_call_signature=("foo.nested",),
                strict=strict,
            )
            unflattened = unflatten(export_module)
            self.compare_outputs(export_module.module(), unflattened, inps)
            unflattened.foo.nested = NestedChild()
            self.compare_outputs(export_module.module(), unflattened, inps)

            # Test tree spec mismatched input
            orig_outs = export_module.module()(*inps)
            new_inps = *inps, torch.rand(2, 3)
            with self.assertRaisesRegex(
                TypeError,
                "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
            ):
                unflattened(new_inps)

            # With flat args adapter
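            # The adapter below drops the extra trailing argument so the flattened inputs match the exported input spec again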
            class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
                def adapt(
                    self,
                    target_spec: TreeSpec,
                    input_spec: TreeSpec,
                    input_args: List[Any],
                ) -> List[Any]:
                    while len(input_args) > 2:
                        input_args.pop(-1)
                    return input_args

            unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
            new_outs = unflattened(*new_inps)
            self.assertTrue(torch.allclose(orig_outs, new_outs))

    def test_unflatten_param_list_dict(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                for i in range(2):
                    x = x + self.param_list[i]
                    x = x + self.param_dict[f"key_{i}"]
                return x

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=False,
        )
        ep.graph.eliminate_dead_code()
        unflattened = unflatten(ep)
        self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))

    def test_unflatten_wrong_input(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                a = x.sum()
                for i in range(2):
                    a = a + self.param_list[i].sum()
                    a = a + self.param_dict[f"key_{i}"].sum()
                return a

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            export_module.module()(torch.randn(6, 6))

        unflattened = unflatten(export_module)
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            unflattened(torch.randn(6, 6))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_with_inplace_compile(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        # In-place compilation should work; pass fullgraph to ensure there are no graph breaks.
        from torch._dynamo.backends.debugging import ExplainWithBackend

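        # ExplainWithBackend wraps the real backend and records each graph Dynamo produces;
        # exactly one recorded graph means compilation hit no graph breaks.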
        eb = ExplainWithBackend("inductor")
        unflattened.foo.compile(backend=eb, fullgraph=True)
        inputs = (torch.randn(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.assertEqual(len(eb.graphs), 1)

    def test_fx_trace(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x[0] + x[1]
                x = x + y["foo"]
                return x

        orig_eager = MyModule()
        inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
        export_module = export(orig_eager, inputs, {})

        unflattened = unflatten(export_module)
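        # torch.fx.PH marks each flattened input as a placeholder so symbolic_trace keeps it
        # as a graph input rather than baking in a concrete value.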
        torch.fx.symbolic_trace(
            unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
        )

    def test_double_nested_submodule(self):
        class SubSubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.subsubmod = SubSubMod()

            def forward(self, x):
                return x - x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod.subsubmod(x)

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_unflatten_container_type(self):
        class Leaf(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()
                self.buffer = torch.nn.Buffer(torch.randn(4, 4))

            def forward(self, x, z):
                return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x, z):
                y = self.bar.buffer + x + z[0] + z[1]
                return self.bar(x, z) + y.sum()

        inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
        mod = Foo()
        ep_strict = torch.export.export(mod, inp)
        ep_non_strict = torch.export.export(mod, inp, strict=False)

        gm_unflat_non_strict = unflatten(ep_non_strict)
        ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
        self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))

    def test_unflattened_module_nodes_has_meta_val(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x, x * x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + sum(self.submod(x))

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

        def check_meta(gm):
            for n in gm.graph.nodes:
                if n.op == "output":
                    continue
                self.assertTrue(n.meta.get("val") is not None)

        for m in unflattened.modules():
            check_meta(m)

    def test_unflatten_requires_grad_param(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        ep = export(mod, inputs)
        unflattened = unflatten(ep)
        self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
        self.assertTrue(unflattened.p.requires_grad is False)

    def test_placeholder_and_get_attr_ordering_after_unflattened(self):
        class TransposeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        x = torch.randn(32, 3, 64, 64)
        exported_program = export(TransposeModule(), args=(x,))
        unflattened_module = unflatten(exported_program)

        # Check the inputs of the created call_module node are in order
        call_module_input_order = []
        for node in unflattened_module.graph.nodes:
            if node.op == "call_module":
                transpose_module = unflattened_module.get_submodule(node.target)
                for sub_node in transpose_module.graph.nodes:
                    if sub_node.op == "placeholder" or sub_node.op == "get_attr":
                        call_module_input_order.append(sub_node.op)
        self.assertEqual(
            call_module_input_order, ["placeholder", "get_attr", "get_attr"]
        )

    def test_unflatten_constant_tensor(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.initializer = 0.1

            def forward(self, x):
                return x + torch.tensor(self.initializer)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @skipIfTorchDynamo("custom objects not supported in dynamo yet")
    def test_unflatten_constant_obj(self):
        init_torchbind_implementations()

        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flat_ctx):
                return cls(**dict(flat_ctx))

            def add_tensor(self, z):
                return (self.x + self.y) * z

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        with enable_torchbind_tracing():
            export_module = torch.export.export(
                Mod(), (torch.randn((2, 3)),), strict=False
            )
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    # Skip connections are not supported yet
    @unittest.expectedFailure
    def test_unflatten_skipped_call_module(self):
        class C(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return a.d(x.cos())

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = C()

            def forward(self, x):
                return self.c(x) + x

        class D(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sin()

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B()
                self.d = D()

            def forward(self, x):
                return self.b(x)

        a = A()

        # The call chain looks like this:
        # A -> B -> C -> A.d
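        # C reaches A.d through the global `a` rather than through its own submodule
        # hierarchy, so the call "skips" back up to the root; unflatten cannot reproduce
        # this yet, hence the expectedFailure above.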
        ep = torch.export.export(a, (torch.randn(3),), strict=False)
        unflattened = unflatten(ep)

    def test_nested_leaf_non_strict(self):
        class Leaf(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()

            def forward(self, x):
                return self.leaf(x) + 2

        class TopLevel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + 3

        ep = torch.export.export(
            TopLevel(),
            (torch.randn(3),),
            strict=False,
            preserve_module_call_signature=("nested",),
        )

        torch.export.unflatten(ep)

    def test_unflatten_submodule_ordering(self):
        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod2 = Module2()
                self.mod3 = self.mod2
                self.mod1 = Module1()

            def forward(self, x):
                return self.mod3(self.mod2(self.mod1(x)))

        mod = Module()

        ep = torch.export.export(mod, (torch.randn(3, 4),))

        unflattened = torch.export.unflatten(ep)
        fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
        self.assertEqual(len(fqn_list), 4)
        self.assertEqual(
            [x for x, _ in mod.named_modules(remove_duplicate=False)],
            fqn_list,
        )

    def test_duplicate_placeholder(self):
        N, C, H, W = 1, 2, 2, 3

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layer = torch.nn.LayerNorm([C, H, W])
                self.norms = torch.nn.ModuleList(
                    [
                        layer,  # reuse layer norm
                        layer,
                        layer,
                    ]
                )

            def forward(self, input_):
                for i in range(len(self.norms)):
                    output = self.norms[i](input_)
                    input_ = output
                return output

        mod = MyModule()
        input_ = torch.randn(N, C, H, W)

        ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
        umod = unflatten(ep_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

        ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
        umod = unflatten(ep_non_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

    def test_simple_alias(self):
        # handle weight sharing, check tensor ids after unflattening
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # alias param
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = self.bias

            def forward(self, x):
                return self.m(x) + self.bias

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(id(unep.m.bias) == id(unep.bias))

        # handle aliasing where one alias is unused
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = (
                    self.bias
                )  # self.bias is unused, aliasing should be handled

            def forward(self, x):
                return self.m(x)

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(torch.allclose(unep(*inps), m(*inps)))

    def test_attr_as_submod_input(self):
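        # The root-level buffer `const` is fed as an explicit argument to each `layer` call,
        # so unflattening is expected to route that attribute into the child submodule calls.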
        class layer(torch.nn.Module):
            def forward(self, x, const) -> torch.Tensor:
                return x + const

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Buffer(torch.ones(4, 8))
                self.layers = torch.nn.ModuleList([layer() for _ in range(2)])

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for layer in self.layers:
                    x = layer(x, self.const)
                return x

        mod = M()
        x = torch.randn(4, 8)
        ep = export(mod, (x,))
        unflattened = unflatten(ep)
        torch.testing.assert_close(unflattened(x), mod(x))

    def test_dedup_sym_size(self):
        # Here, sym_size & floordiv are used in 3 subgraphs (top-level, m1, m2),
        # but only one copy of sym_size is created in the initial export graph.
        # For m1, sym_size & floordiv should be recomputed inside the subgraph since we preserve its call signature,
        # but for m2, floordiv should be passed in as a placeholder.
        # Test that this is preserved, and that the unflattened module runs correctly.
        class M1(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M2(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()
                self.m2 = M2()

            def forward(self, x, y):
                d = x.size(0) // 2
                m1_res = self.m1(x, y)
                m2_res = self.m2(x, y)
                return y[d:] + m1_res + m2_res

        inputs = (torch.ones(10), torch.ones(10))
        d_ = torch.export.Dim("foo", max=2048)
        d = 2 * d_
        ep = torch.export.export(
            M(),
            inputs,
            dynamic_shapes=((d,), (d,)),
            strict=False,
            preserve_module_call_signature=("m1",),
        )
        unflat = unflatten(ep)
        unflat(*inputs)

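        # Count aten.sym_size.int calls per graph: the top level and m1 each recompute it,
        # while m2 receives the derived value as a placeholder and so has none.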
        fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
            torch.ops.aten.sym_size.int
        )
        self.assertEqual(fn_count_sym_size(unflat.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)

    def test_unflatten_eager(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
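        # With the interpreter disabled, unflattened modules are expected to run their
        # generated forward code directly rather than dispatching through
        # torch.fx.Interpreter (hence the _run_with_interpeter checks below).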
        with _disable_interpreter():
            unflattened = unflatten(export_module)

        self.assertEqual(unflattened._run_with_interpeter, False)
        self.assertEqual(unflattened.foo._run_with_interpeter, False)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

        # Check composability with symbolic tracing, since torchrec DDP uses the
        # symbolic tracer
        symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
        self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs)))

        # torch.compile submodule
        unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
        self.compare_outputs(orig_eager, unflattened, inputs)


if __name__ == "__main__":
    run_tests()