# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Any, List

import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._export.utils import (
    get_buffer,
    get_param,
    is_buffer,
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.export.unflatten import _disable_interpreter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import (
    LeafSpec,
    tree_flatten,
    tree_unflatten,
    TreeSpec,
    treespec_dumps,
    treespec_loads,
)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestUnflatten(TestCase):
    def compare_outputs(self, eager, unflattened, args):
        orig_output = eager(*args)
        unflattened_output = unflattened(*args)
        self.assertTrue(torch.allclose(orig_output, unflattened_output))

    def test_unflatten_nested(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

    def test_unflatten_buffer_mutation(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                self.child2buffer.add_(x)
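                # In-place buffer mutation; the unflattened module should apply
                # the same update to its own copy of the buffer.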
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.foo(x)
                return x * self.rootparam

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        # Buffer should look the same before and after one run
        eager_buffer = eager_module.foo.child2buffer
        unflattened_buffer = unflattened_module.foo.child2buffer
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

        inputs = (torch.rand(2, 3),)
        eager_module(*inputs)
        unflattened_module(*inputs)
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

    def test_unflatten_nested_access(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x + self.foo.child2buffer
                x = self.foo(x)
                return x

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(eager_module, unflattened_module, inputs)

    def test_unflatten_shared_submodule(self):
        class Shared(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layernorm = torch.nn.LayerNorm(10)
                self.sub_net = torch.nn.Sequential(
                    layernorm,
                    torch.nn.ReLU(),
                    layernorm,
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.sub_net(x)

        eager_module = Shared()
        inps = (torch.rand(10),)
        export_module = export(eager_module, inps, {})
        unflattened_module = unflatten(export_module)
        self.compare_outputs(eager_module, unflattened_module, inps)
        self.assertTrue(hasattr(unflattened_module, "sub_net"))
        for i in range(len(eager_module.sub_net)):
            self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
        self.assertEqual(
            id(getattr(unflattened_module.sub_net, "0")),
            id(getattr(unflattened_module.sub_net, "2")),
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
    def test_unflatten_preserve_signature(self):
        class NestedChild(torch.nn.Module):
            def forward(self, zx, y):
                return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()

            def forward(self, x, y):
                z = torch.ones_like(x)
                xw = self.nested((z, x), y={"key": y})
                return xw["w"] + z - xw["x"]

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

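        # Export with strict and non-strict mode; preserving the call signature
        # of "foo.nested" lets that submodule be swapped for a fresh NestedChild
        # below without changing results.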
        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)
        for strict in [True, False]:
            export_module = export(
                orig_eager,
                inps,
                {},
                preserve_module_call_signature=("foo.nested",),
                strict=strict,
            )
            unflattened = unflatten(export_module)
            self.compare_outputs(export_module.module(), unflattened, inps)
            unflattened.foo.nested = NestedChild()
            self.compare_outputs(export_module.module(), unflattened, inps)

            # Test tree spec mismatched input
            orig_outs = export_module.module()(*inps)
            new_inps = *inps, torch.rand(2, 3)
            with self.assertRaisesRegex(
                TypeError,
                "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
            ):
                unflattened(new_inps)

            # With flat args adapter
            class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
                def adapt(
                    self,
                    target_spec: TreeSpec,
                    input_spec: TreeSpec,
                    input_args: List[Any],
                ) -> List[Any]:
                    while len(input_args) > 2:
                        input_args.pop(-1)
                    return input_args

            unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
            new_outs = unflattened(*new_inps)
            self.assertTrue(torch.allclose(orig_outs, new_outs))

    def test_unflatten_param_list_dict(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                for i in range(2):
                    x = x + self.param_list[i]
                    x = x + self.param_dict[f"key_{i}"]
                return x

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=False,
        )
        ep.graph.eliminate_dead_code()
        unflattened = unflatten(ep)
        self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))

    def test_unflatten_wrong_input(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                a = x.sum()
                for i in range(2):
                    a = a + self.param_list[i].sum()
                    a = a + self.param_dict[f"key_{i}"].sum()
                return a

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            export_module.module()(torch.randn(6, 6))

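        # The unflattened module should reject a mismatched input shape with the
        # same error as the exported module above.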
        unflattened = unflatten(export_module)
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            unflattened(torch.randn(6, 6))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_with_inplace_compile(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        # in-place compilation should work. Pass fullgraph to ensure no graph breaks.
        from torch._dynamo.backends.debugging import ExplainWithBackend

        eb = ExplainWithBackend("inductor")
        unflattened.foo.compile(backend=eb, fullgraph=True)
        inputs = (torch.randn(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.assertEqual(len(eb.graphs), 1)

    def test_fx_trace(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x[0] + x[1]
                x = x + y["foo"]
                return x

        orig_eager = MyModule()
        inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
        export_module = export(orig_eager, inputs, {})

        unflattened = unflatten(export_module)
        torch.fx.symbolic_trace(
            unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
        )

    def test_double_nested_submodule(self):
        class SubSubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.subsubmod = SubSubMod()

            def forward(self, x):
                return x - x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod.subsubmod(x)

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_unflatten_container_type(self):
        class Leaf(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()
                self.buffer = torch.nn.Buffer(torch.randn(4, 4))

            def forward(self, x, z):
                return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x, z):
                y = self.bar.buffer + x + z[0] + z[1]
                return self.bar(x, z) + y.sum()

        inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
        mod = Foo()
        ep_strict = torch.export.export(mod, inp)
        ep_non_strict = torch.export.export(mod, inp, strict=False)

        gm_unflat_non_strict = unflatten(ep_non_strict)
        ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
        self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))

    def test_unflattened_module_nodes_has_meta_val(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x, x * x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + sum(self.submod(x))

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

        def check_meta(gm):
            for n in gm.graph.nodes:
                if n.op == "output":
                    continue
                self.assertTrue(n.meta.get("val") is not None)

        for m in unflattened.modules():
            check_meta(m)

    def test_unflatten_requires_grad_param(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        ep = export(mod, inputs)
        unflattened = unflatten(ep)
        self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
        self.assertTrue(unflattened.p.requires_grad is False)

    def test_placeholder_and_get_attr_ordering_after_unflattened(self):
        class TransposeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        x = torch.randn(32, 3, 64, 64)
        exported_program = export(TransposeModule(), args=(x,))
        unflattened_module = unflatten(exported_program)

        # Check the inputs of the created call_module node are in order
        call_module_input_order = []
        for node in unflattened_module.graph.nodes:
            if node.op == "call_module":
                transpose_module = unflattened_module.get_submodule(node.target)
                for sub_node in transpose_module.graph.nodes:
                    if sub_node.op == "placeholder" or sub_node.op == "get_attr":
                        call_module_input_order.append(sub_node.op)
        self.assertEqual(
            call_module_input_order, ["placeholder", "get_attr", "get_attr"]
        )

    def test_unflatten_constant_tensor(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.initializer = 0.1

            def forward(self, x):
                return x + torch.tensor(self.initializer)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

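        # The constant tensor created inside SubMod.forward should still yield
        # the same outputs after unflattening.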
        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @skipIfTorchDynamo("custom objects not supported in dynamo yet")
    def test_unflatten_constant_obj(self):
        init_torchbind_implementations()

        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flat_ctx):
                return cls(**dict(flat_ctx))

            def add_tensor(self, z):
                return (self.x + self.y) * z

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        with enable_torchbind_tracing():
            export_module = torch.export.export(
                Mod(), (torch.randn((2, 3)),), strict=False
            )
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    # skip connection is not supported yet
    @unittest.expectedFailure
    def test_unflatten_skipped_call_module(self):
        class C(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return a.d(x.cos())

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = C()

            def forward(self, x):
                return self.c(x) + x

        class D(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sin()

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B()
                self.d = D()

            def forward(self, x):
                return self.b(x)

        a = A()

        # The call chain looks like this:
        # A -> B -> C -> A.d
        ep = torch.export.export(a, (torch.randn(3),), strict=False)
        unflattened = unflatten(ep)

    def test_nested_leaf_non_strict(self):
        class Leaf(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()

            def forward(self, x):
                return self.leaf(x) + 2

        class TopLevel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + 3

        ep = torch.export.export(
            TopLevel(),
            (torch.randn(3),),
            strict=False,
            preserve_module_call_signature=("nested",),
        )

        torch.export.unflatten(ep)

    def test_unflatten_submodule_ordering(self):
        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

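        # Module below aliases mod2 as mod3; unflatten should keep the duplicate
        # entry so the FQN ordering matches the eager module.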
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod2 = Module2()
                self.mod3 = self.mod2
                self.mod1 = Module1()

            def forward(self, x):
                return self.mod3(self.mod2(self.mod1(x)))

        mod = Module()

        ep = torch.export.export(mod, (torch.randn(3, 4),))

        unflattened = torch.export.unflatten(ep)
        fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
        self.assertEqual(len(fqn_list), 4)
        self.assertEqual(
            [x for x, _ in mod.named_modules(remove_duplicate=False)],
            fqn_list,
        )

    def test_duplicate_placeholder(self):
        N, C, H, W = 1, 2, 2, 3

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layer = torch.nn.LayerNorm([C, H, W])
                self.norms = torch.nn.ModuleList(
                    [
                        layer,  # reuse layer norm
                        layer,
                        layer,
                    ]
                )

            def forward(self, input_):
                for i in range(len(self.norms)):
                    output = self.norms[i](input_)
                    input_ = output
                return output

        mod = MyModule()
        input_ = torch.randn(N, C, H, W)

        ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
        umod = unflatten(ep_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

        ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
        umod = unflatten(ep_non_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

    def test_simple_alias(self):
        # handle weight sharing, check tensor ids after unflattening
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # alias param
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = self.bias

            def forward(self, x):
                return self.m(x) + self.bias

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(id(unep.m.bias) == id(unep.bias))

        # handle aliasing where one alias is unused
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = (
                    self.bias
                )  # self.bias is unused, aliasing should be handled

            def forward(self, x):
                return self.m(x)

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(torch.allclose(unep(*inps), m(*inps)))

    def test_attr_as_submod_input(self):
        class layer(torch.nn.Module):
            def forward(self, x, const) -> torch.Tensor:
                return x + const

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Buffer(torch.ones(4, 8))
                self.layers = torch.nn.ModuleList([layer() for _ in range(2)])

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for layer in self.layers:
                    x = layer(x, self.const)
                return x

        mod = M()
        x = torch.randn(4, 8)
        ep = export(mod, (x,))
        unflattened = unflatten(ep)
        torch.testing.assert_close(unflattened(x), mod(x))

    def test_dedup_sym_size(self):
        # Here, sym_size & floor div are used in 3 subgraphs (top-level, m1, m2),
        # but only one copy of sym_size is created in the initial export graph.
        # For m1, sym_size & floordiv should be copied as recompute since we preserve the call signature,
        # but for m2 floordiv should be passed in as a placeholder.
        # Test that this is preserved, and the unflattened module runs correctly.
        class M1(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M2(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()
                self.m2 = M2()

            def forward(self, x, y):
                d = x.size(0) // 2
                m1_res = self.m1(x, y)
                m2_res = self.m2(x, y)
                return y[d:] + m1_res + m2_res

        inputs = (torch.ones(10), torch.ones(10))
        d_ = torch.export.Dim("foo", max=2048)
        d = 2 * d_
        ep = torch.export.export(
            M(),
            inputs,
            dynamic_shapes=((d,), (d,)),
            strict=False,
            preserve_module_call_signature=("m1",),
        )
        unflat = unflatten(ep)
        unflat(*inputs)

        fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
            torch.ops.aten.sym_size.int
        )
        self.assertEqual(fn_count_sym_size(unflat.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)

    def test_unflatten_eager(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        with _disable_interpreter():
            unflattened = unflatten(export_module)

        self.assertEqual(unflattened._run_with_interpeter, False)
        self.assertEqual(unflattened.foo._run_with_interpeter, False)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

        # Check composability with symbolic trace, as torchrec ddp uses symbolic
        # tracer
        symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
        self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs)))

        # torch.compile submodule
        unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
        self.compare_outputs(orig_eager, unflattened, inputs)


if __name__ == "__main__":
    run_tests()