1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from collections import OrderedDict 6from typing import Any, List, Tuple 7 8import torch 9import torch.nn as nn 10from torch.testing._internal.jit_utils import JitTestCase 11 12 13# Make the helper files in test/ importable 14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15sys.path.append(pytorch_test_dir) 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25class TestModuleContainers(JitTestCase): 26 def test_sequential_intermediary_types(self): 27 class A(torch.nn.Module): 28 def forward(self, x): 29 return x + 3 30 31 class B(torch.nn.Module): 32 def forward(self, x): 33 return {"1": x} 34 35 class C(torch.nn.Module): 36 def __init__(self) -> None: 37 super().__init__() 38 self.foo = torch.nn.Sequential(A(), B()) 39 40 def forward(self, x): 41 return self.foo(x) 42 43 self.checkModule(C(), (torch.tensor(1),)) 44 45 def test_moduledict(self): 46 class Inner(torch.nn.Module): 47 def forward(self, x): 48 return x + 10 49 50 class Inner2(torch.nn.Module): 51 def forward(self, x): 52 return x * 2 53 54 class Inner3(torch.nn.Module): 55 def forward(self, x): 56 return (x - 4) * 3 57 58 class M(torch.nn.Module): 59 def __init__(self) -> None: 60 super().__init__() 61 modules = OrderedDict( 62 [ 63 ("one", Inner()), 64 ("two", Inner2()), 65 ("three", Inner3()), 66 ] 67 ) 68 self.moduledict = nn.ModuleDict(modules) 69 70 def forward(self, x, skip_name): 71 # type: (Tensor, str) 72 names = torch.jit.annotate(List[str], []) 73 values = [] 74 for name in self.moduledict: 75 names.append(name) 76 77 for name, mod in self.moduledict.items(): 78 if name != skip_name: 79 names.append(name) 80 x = mod(x) 81 values.append(x) 82 83 for mod in self.moduledict.values(): 84 x = mod(x) 85 values.append(x) 86 87 for key in self.moduledict.keys(): 88 names.append(key) 89 90 return x, names 91 92 class M2(M): 93 def forward(self, x, skip_name): 94 # type: (Tensor, str) 95 names = torch.jit.annotate(List[str], []) 96 values = [] 97 x2 = x 98 iter = 0 99 for name in self.moduledict: 100 names.append(name) 101 102 for i, (name, mod) in enumerate(self.moduledict.items()): 103 iter += i 104 if name != skip_name: 105 names.append(name) 106 x = mod(x) 107 values.append(x) 108 109 for i, mod in enumerate(self.moduledict.values()): 110 iter += i 111 x = mod(x) 112 values.append(x) 113 114 for i, key in enumerate(self.moduledict.keys()): 115 iter += i 116 names.append(key) 117 118 for mod, mod in zip(self.moduledict.values(), self.moduledict.values()): 119 iter += i 120 x2 = mod(mod(x2)) 121 122 return x, x2, names, iter 123 124 for name in ["", "one", "two", "three"]: 125 inp = torch.tensor(1) 126 self.checkModule(M(), (inp, name)) 127 self.checkModule(M2(), (inp, name)) 128 129 def test_custom_container_forward(self): 130 class Inner(torch.nn.Module): 131 def forward(self, x): 132 return x + 10 133 134 class CustomSequential(nn.Sequential): 135 def __init__(self) -> None: 136 super().__init__(nn.ReLU(), Inner()) 137 138 def forward(self, x): 139 x = x + 3 140 for mod in self: 141 x = mod(x) 142 return x - 5 143 144 self.checkModule(CustomSequential(), (torch.tensor(0.5),)) 145 146 class CustomModuleList(nn.ModuleList): 147 def __init__(self) -> None: 148 super().__init__([nn.ReLU(), Inner()]) 149 150 def forward(self, x): 151 x = x + 3 152 for mod in self: 153 x = mod(x) 154 return x - 5 155 156 self.checkModule(CustomModuleList(), (torch.tensor(0.5),)) 157 158 class CustomModuleDict(nn.ModuleDict): 159 def __init__(self) -> None: 160 super().__init__( 161 OrderedDict( 162 [ 163 ("one", Inner()), 164 ("two", nn.ReLU()), 165 ("three", Inner()), 166 ] 167 ) 168 ) 169 170 def forward(self, x): 171 x = x + 3 172 names = torch.jit.annotate(List[str], []) 173 for name, mod in self.items(): 174 x = mod(x) 175 names.append(name) 176 return names, x - 5 177 178 self.checkModule(CustomModuleDict(), (torch.tensor(0.5),)) 179 180 def test_script_module_list_sequential(self): 181 class M(torch.jit.ScriptModule): 182 def __init__(self, mod_list): 183 super().__init__() 184 self.mods = mod_list 185 186 @torch.jit.script_method 187 def forward(self, v): 188 for m in self.mods: 189 v = m(v) 190 return v 191 192 with torch.jit.optimized_execution(False): 193 m = M(nn.Sequential(nn.ReLU())) 194 self.assertExportImportModule(m, (torch.randn(2, 2),)) 195 196 def test_script_modulelist_index(self): 197 class Sub(torch.nn.Module): 198 def __init__(self, i): 199 super().__init__() 200 self.i = i 201 202 def forward(self, thing): 203 return thing - self.i 204 205 class M(torch.nn.Module): 206 def __init__(self) -> None: 207 super().__init__() 208 self.mods = nn.ModuleList([Sub(i) for i in range(10)]) 209 210 def forward(self, v): 211 v = self.mods[4].forward(v) 212 v = self.mods[-1].forward(v) 213 v = self.mods[-9].forward(v) 214 return v 215 216 x = torch.tensor(1) 217 self.checkModule(M(), (x,)) 218 219 class MForward(torch.nn.Module): 220 def __init__(self) -> None: 221 super().__init__() 222 self.mods = nn.ModuleList([Sub(i) for i in range(10)]) 223 224 def forward(self, v): 225 v = self.mods[4](v) 226 v = self.mods[-1](v) 227 v = self.mods[-9](v) 228 return v 229 230 self.checkModule(MForward(), (torch.tensor(1),)) 231 232 class M2(M): 233 def forward(self, v): 234 return self.mods[-11].forward(v) 235 236 with self.assertRaisesRegexWithHighlight( 237 Exception, "Index -11 out of range", "self.mods[-11]" 238 ): 239 torch.jit.script(M2()) 240 241 class M3(M): 242 def forward(self, v): 243 i = 3 244 return self.mods[i].forward(v) 245 246 with self.assertRaisesRegexWithHighlight( 247 Exception, "Enumeration is supported", "self.mods[i]" 248 ): 249 torch.jit.script(M3()) 250 251 class M4(M): 252 def forward(self, v): 253 i = 3 254 return self.mods[i].forward(v) 255 256 with self.assertRaisesRegex(Exception, "will fail because i is not a literal"): 257 torch.jit.script(M4()) 258 259 def test_module_interface_special_methods(self): 260 class CustomModuleInterface(torch.nn.Module): 261 pass 262 263 class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList): 264 def __init__(self, modules=None): 265 CustomModuleInterface.__init__(self) 266 torch.nn.ModuleList.__init__(self, modules) 267 268 class CustomSequential(CustomModuleInterface, torch.nn.Sequential): 269 def __init__(self, modules=None): 270 CustomModuleInterface.__init__(self) 271 torch.nn.Sequential.__init__(self, modules) 272 273 class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict): 274 def __init__(self, modules=None): 275 CustomModuleInterface.__init__(self) 276 torch.nn.ModuleDict.__init__(self, modules) 277 278 class MyModule(torch.nn.Module): 279 def __init__(self) -> None: 280 super().__init__() 281 # work around aliasing issue for 'is' operator by scripting ReLU up front 282 self.submod = torch.jit.script(torch.nn.ReLU()) 283 self.modulelist = CustomModuleList([self.submod]) 284 self.sequential = CustomSequential(self.submod) 285 self.moduledict = CustomModuleDict({"submod": self.submod}) 286 287 def forward(self, inputs): 288 assert ( 289 self.modulelist[0] is self.submod 290 ), "__getitem__ failing for ModuleList" 291 assert len(self.modulelist) == 1, "__len__ failing for ModuleList" 292 for module in self.modulelist: 293 assert module is self.submod, "__iter__ failing for ModuleList" 294 295 assert ( 296 self.sequential[0] is self.submod 297 ), "__getitem__ failing for Sequential" 298 assert len(self.sequential) == 1, "__len__ failing for Sequential" 299 for module in self.sequential: 300 assert module is self.submod, "__iter__ failing for Sequential" 301 302 assert ( 303 self.moduledict["submod"] is self.submod 304 ), "__getitem__ failing for ModuleDict" 305 assert len(self.moduledict) == 1, "__len__ failing for ModuleDict" 306 307 # note: unable to index moduledict with a string variable currently 308 i = 0 309 for key in self.moduledict: 310 i += 1 311 assert i == len(self.moduledict), "iteration failing for ModuleDict" 312 313 assert "submod" in self.moduledict, "__contains__ fails for ModuleDict" 314 315 for key in self.moduledict.keys(): 316 assert key == "submod", "keys() fails for ModuleDict" 317 318 for item in self.moduledict.items(): 319 assert item[0] == "submod", "items() fails for ModuleDict" 320 assert item[1] is self.submod, "items() fails for ModuleDict" 321 322 for value in self.moduledict.values(): 323 assert value is self.submod, "values() fails for ModuleDict" 324 325 return inputs 326 327 m = MyModule() 328 self.checkModule(m, [torch.randn(2, 2)]) 329 330 def test_special_method_with_override(self): 331 class CustomModuleInterface(torch.nn.Module): 332 pass 333 334 class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList): 335 def __init__(self, modules=None): 336 CustomModuleInterface.__init__(self) 337 torch.nn.ModuleList.__init__(self, modules) 338 339 def __len__(self): 340 # this is arbitrary, just to check that the overridden py __len__ from 341 # CustomModuleList takes precedence over the automatically generated 342 # __len__ added by the jit compiler 343 return 2 344 345 class MyModule(torch.nn.Module): 346 def __init__(self) -> None: 347 super().__init__() 348 # work around aliasing issue for 'is' operator by scripting ReLU up front 349 self.submod = torch.jit.script(torch.nn.ReLU()) 350 self.modulelist = CustomModuleList([self.submod]) 351 352 def forward(self, inputs): 353 assert len(self.modulelist) == 2, "__len__ failing for ModuleList" 354 return inputs 355 356 m = MyModule() 357 self.checkModule(m, [torch.randn(2, 2)]) 358 mm = torch.jit.script(m) 359 360 def test_moduledict_getitem(self): 361 class MyModule(torch.nn.Module): 362 def __init__(self) -> None: 363 super().__init__() 364 self.relu = torch.jit.script(torch.nn.ReLU()) 365 self.tanh = torch.jit.script(torch.nn.Tanh()) 366 self.moduledict = torch.nn.ModuleDict( 367 {"relu": self.relu, "tanh": self.tanh} 368 ) 369 370 def forward(self, input): 371 assert self.moduledict["relu"] is self.relu 372 assert self.moduledict["tanh"] is self.tanh 373 return input 374 375 m = MyModule() 376 self.checkModule(m, [torch.randn(2, 2)]) 377 378 def test_moduledict_keyerror(self): 379 class BadModule(torch.nn.Module): 380 def __init__(self) -> None: 381 super().__init__() 382 self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None}) 383 384 def forward(self, input): 385 assert self.moduledict["blah"] == "blah", "this is a keyerror" 386 387 with self.assertRaisesRegexWithHighlight( 388 RuntimeError, "Key Error, blah", 'self.moduledict["blah"' 389 ): 390 b = BadModule() 391 torch.jit.script(b) 392 393 class AnotherBadModule(torch.nn.Module): 394 def __init__(self) -> None: 395 super().__init__() 396 self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None}) 397 398 def forward(self, input): 399 idx = "blah" 400 assert self.moduledict[idx] == "blah", "this is a string literal error" 401 402 with self.assertRaisesRegexWithHighlight( 403 RuntimeError, 404 "Unable to extract string literal index. " 405 "ModuleDict indexing is only supported with string literals. " 406 "For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail " 407 "because i is not a literal.", 408 "self.moduledict[idx]", 409 ): 410 b = AnotherBadModule() 411 torch.jit.script(b) 412 413 def test_normal_list_attribute_with_modules_error(self): 414 """ 415 Test that an attempt to script a module with a regular list attribute 416 containing other modules fails with a relevant error message. 417 """ 418 419 class Mod(torch.nn.Module): 420 def __init__(self) -> None: 421 super().__init__() 422 self.a = [torch.nn.ReLU(), torch.nn.ReLU()] 423 424 def forward(self): 425 return len(self.a) 426 427 error_msg = "Could not infer type of list element: Cannot infer concrete type of torch.nn.Module" 428 with self.assertRaisesRegexWithHighlight(RuntimeError, error_msg, "self.a"): 429 torch.jit.script(Mod()) 430 431 def test_empty_dict_override_contains(self): 432 class CustomModuleInterface(torch.nn.Module): 433 pass 434 435 class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict): 436 def __init__(self, modules=None): 437 CustomModuleInterface.__init__(self) 438 torch.nn.ModuleDict.__init__(self, modules) 439 440 class MyModule(torch.nn.Module): 441 def __init__(self) -> None: 442 super().__init__() 443 # work around aliasing issue for 'is' operator by scripting ReLU up front 444 self.submod = torch.jit.script(torch.nn.ReLU()) 445 self.moduledict = CustomModuleDict() 446 447 def forward(self, inputs): 448 assert ( 449 "submod" not in self.moduledict 450 ), "__contains__ fails for ModuleDict" 451 return inputs 452 453 m = MyModule() 454 self.checkModule(m, [torch.randn(2, 2)]) 455 456 def test_typed_module_dict(self): 457 """ 458 Test that a type annotation can be provided for a ModuleDict that allows 459 non-static indexing. 460 """ 461 462 @torch.jit.interface 463 class ModuleInterface(torch.nn.Module): 464 def forward(self, inp: Any) -> Any: 465 pass 466 467 class ImplementsInterface(torch.nn.Module): 468 def forward(self, inp: Any) -> Any: 469 if isinstance(inp, torch.Tensor): 470 return torch.max(inp, dim=0) 471 472 return inp 473 474 class DoesNotImplementInterface(torch.nn.Module): 475 def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 476 return torch.max(inp, dim=0) 477 478 # Test annotation of submodule. 479 class Mod(torch.nn.Module): 480 def __init__(self) -> None: 481 super().__init__() 482 self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) 483 484 def forward(self, x: torch.Tensor, key: str) -> Any: 485 value: ModuleInterface = self.d[key] 486 return value.forward(x) 487 488 m = Mod() 489 self.checkModule(m, (torch.randn(2, 2), "module")) 490 491 # Test annotation of self. 492 class ModDict(torch.nn.ModuleDict): 493 def __init__(self) -> None: 494 super().__init__({"module": ImplementsInterface()}) 495 496 def forward(self, x: torch.Tensor, key: str) -> Any: 497 submodule: ModuleInterface = self[key] 498 return submodule.forward(x) 499 500 m = ModDict() 501 self.checkModule(m, (torch.randn(2, 2), "module")) 502 503 # Test error message thrown when annotated attribute does not comply with the 504 # annotation. 505 class ModWithWrongAnnotation(torch.nn.ModuleDict): 506 def __init__(self) -> None: 507 super().__init__() 508 self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()}) 509 510 def forward(self, x: torch.Tensor, key: str) -> Any: 511 submodule: ModuleInterface = self.d[key] 512 return submodule.forward(x) 513 514 with self.assertRaisesRegexWithHighlight( 515 RuntimeError, r"Attribute module is not of annotated type", "self.d[key]" 516 ): 517 torch.jit.script(ModWithWrongAnnotation()) 518 519 def test_typed_module_list(self): 520 """ 521 Test that a type annotation can be provided for a ModuleList that allows 522 non-static indexing. 523 """ 524 525 @torch.jit.interface 526 class ModuleInterface(torch.nn.Module): 527 def forward(self, inp: Any) -> Any: 528 pass 529 530 class ImplementsInterface(torch.nn.Module): 531 def forward(self, inp: Any) -> Any: 532 if isinstance(inp, torch.Tensor): 533 return torch.max(inp, dim=0) 534 535 return inp 536 537 class DoesNotImplementInterface(torch.nn.Module): 538 def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 539 return torch.max(inp, dim=0) 540 541 # Test annotation of submodule. 542 class Mod(torch.nn.Module): 543 def __init__(self) -> None: 544 super().__init__() 545 self.l = torch.nn.ModuleList([ImplementsInterface()]) 546 547 def forward(self, x: torch.Tensor, idx: int) -> Any: 548 value: ModuleInterface = self.l[idx] 549 return value.forward(x) 550 551 m = Mod() 552 self.checkModule(m, (torch.randn(2, 2), 0)) 553 554 # Test annotation of self. 555 class ModList(torch.nn.ModuleList): 556 def __init__(self) -> None: 557 super().__init__([ImplementsInterface()]) 558 559 def forward(self, x: torch.Tensor, idx: int) -> Any: 560 submodule: ModuleInterface = self[idx] 561 return submodule.forward(x) 562 563 m = ModList() 564 self.checkModule(m, (torch.randn(2, 2), 0)) 565 566 # Test error message thrown when annotated attribute does not comply with the 567 # annotation. 568 class ModWithWrongAnnotation(torch.nn.ModuleList): 569 def __init__(self) -> None: 570 super().__init__() 571 self.l = torch.nn.ModuleList([DoesNotImplementInterface()]) 572 573 def forward(self, x: torch.Tensor, idx: int) -> Any: 574 submodule: ModuleInterface = self.l[idx] 575 return submodule.forward(x) 576 577 with self.assertRaisesRegexWithHighlight( 578 RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]" 579 ): 580 torch.jit.script(ModWithWrongAnnotation()) 581 582 def test_module_properties(self): 583 class ModuleWithProperties(torch.nn.Module): 584 __jit_unused_properties__ = ["ignored_attr"] 585 586 def __init__(self, a: int): 587 super().__init__() 588 self.a = a 589 590 def forward(self, a: int, b: int): 591 self.attr = a + b 592 return self.attr 593 594 @property 595 def attr(self): 596 return self.a 597 598 @property 599 def ignored_attr(self): 600 return sum([self.a]) 601 602 @torch.jit.unused 603 @property 604 def ignored_attr_2(self): 605 return sum([self.a]) 606 607 @ignored_attr_2.setter 608 def ignored_attr_2(self, value): 609 self.a = sum([self.a]) 610 611 @attr.setter 612 def attr(self, a: int): 613 if a > 0: 614 self.a = a 615 else: 616 self.a = 0 617 618 class ModuleWithNoSetter(torch.nn.Module): 619 def __init__(self, a: int): 620 super().__init__() 621 self.a = a 622 623 def forward(self, a: int, b: int): 624 self.attr + a + b 625 626 @property 627 def attr(self): 628 return self.a + 1 629 630 self.checkModule( 631 ModuleWithProperties(5), 632 ( 633 5, 634 6, 635 ), 636 ) 637 self.checkModule( 638 ModuleWithProperties(5), 639 ( 640 -5, 641 -6, 642 ), 643 ) 644 self.checkModule( 645 ModuleWithNoSetter(5), 646 ( 647 5, 648 6, 649 ), 650 ) 651 self.checkModule( 652 ModuleWithNoSetter(5), 653 ( 654 -5, 655 -6, 656 ), 657 ) 658 659 mod = ModuleWithProperties(3) 660 scripted_mod = torch.jit.script(mod) 661 662 with self.assertRaisesRegex(AttributeError, "has no attribute"): 663 scripted_mod.ignored_attr 664 665 def test_module_inplace_construct(self): 666 class M(nn.Module): 667 def __init__(self, start: int): 668 super().__init__() 669 self.linear = nn.Linear(3, 3) 670 self.attribute = start 671 self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float)) 672 673 def method(self) -> int: 674 return self.attribute 675 676 @torch.jit.unused 677 def unused_method(self): 678 return self.attribute + self.attribute 679 680 def forward(self, x): 681 return self.linear(self.linear(x)) 682 683 class N(nn.Module): 684 def __init__(self) -> None: 685 super().__init__() 686 self.linear = nn.Linear(4, 4) 687 688 @torch.jit.ignore 689 def ignored_method(self, x): 690 return x 691 692 def forward(self, x): 693 return self.linear(x) 694 695 m = torch.jit.script(M(3)) 696 n = torch.jit.script(N()) 697 698 n._reconstruct(m._c) 699 700 inp = torch.rand((3)) 701 702 # Check that both modules produce the same output. 703 with torch.no_grad(): 704 m_out = m(inp) 705 n_out = n(inp) 706 self.assertEqual(m_out, n_out) 707 708 # Check that ignored method is still intact. 709 self.assertEqual(inp, n.ignored_method(inp)) 710 711 def test_parameterlist_script_getitem(self): 712 class MyModule(nn.Module): 713 def __init__(self) -> None: 714 super().__init__() 715 self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)]) 716 self.parameter_list = nn.ParameterList( 717 [nn.Parameter(torch.zeros(1)) for _ in range(10)] 718 ) 719 720 def forward(self, x): 721 self.module_list[0] 722 self.parameter_list[0] 723 return x 724 725 self.checkModule(MyModule(), (torch.zeros(1))) 726 727 def test_parameterlist_script_iter(self): 728 class MyModule(nn.Module): 729 def __init__(self) -> None: 730 super().__init__() 731 self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)]) 732 self.parameter_list = nn.ParameterList( 733 [nn.Parameter(torch.zeros(1)) for _ in range(10)] 734 ) 735 736 def forward(self, x): 737 r = x 738 for i, p in enumerate(self.parameter_list): 739 r = r + p + i 740 return r 741 742 self.checkModule(MyModule(), (torch.zeros(1),)) 743 744 def test_parameterdict_script_getitem(self): 745 class MyModule(nn.Module): 746 def __init__(self) -> None: 747 super().__init__() 748 self.parameter_dict = nn.ParameterDict( 749 {k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]} 750 ) 751 752 def forward(self, x): 753 return ( 754 self.parameter_dict["a"] * x 755 + self.parameter_dict["b"] * self.parameter_dict["c"] 756 ) 757 758 self.checkModule(MyModule(), (torch.ones(1),)) 759