1# Owner(s): ["oncall: jit"] 2 3import os 4import re 5import sys 6import types 7import typing 8import typing_extensions 9from collections import OrderedDict 10from typing import Dict, List, Optional, Tuple 11 12import torch 13import torch.jit.frontend 14import torch.nn as nn 15from torch import Tensor 16from torch.testing import FileCheck 17 18 19# Make the helper files in test/ importable 20pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 21sys.path.append(pytorch_test_dir) 22from torch.testing._internal.jit_utils import ( 23 _tmp_donotuse_dont_inline_everything, 24 JitTestCase, 25) 26 27 28if __name__ == "__main__": 29 raise RuntimeError( 30 "This test file is not meant to be run directly, use:\n\n" 31 "\tpython test/test_jit.py TESTNAME\n\n" 32 "instead." 33 ) 34 35 36class TestRecursiveScript(JitTestCase): 37 def test_inferred_nonetype(self): 38 class M(nn.Module): 39 def __init__(self) -> None: 40 super().__init__() 41 self.x = None 42 43 def forward(self): 44 assert self.x is None 45 46 m = torch.jit.script(M()) 47 self.checkModule(M(), ()) 48 49 def test_script_function_attribute(self): 50 @torch.jit.script 51 def fn1(x): 52 return x + x 53 54 @torch.jit.script 55 def fn2(x): 56 return x - x 57 58 class M(torch.nn.Module): 59 def __init__(self, fn): 60 super().__init__() 61 self.fn = fn 62 63 def forward(self, x): 64 return self.fn(x) 65 66 fn1_mod = M(fn1) 67 fn2_mod = M(fn2) 68 69 self.checkModule(fn1_mod, (torch.randn(2, 2),)) 70 self.checkModule(fn2_mod, (torch.randn(2, 2),)) 71 72 def test_python_function_attribute(self): 73 class M(torch.nn.Module): 74 def __init__(self, fn): 75 super().__init__() 76 self.fn = fn 77 78 def forward(self, x): 79 return self.fn(x) 80 81 mod = M(torch.sigmoid) 82 83 self.checkModule(mod, (torch.randn(2, 2),)) 84 85 def test_failed_function_compilation(self): 86 def fn(x): 87 return i_dont_exist # noqa: F821 88 89 class M(torch.nn.Module): 90 def __init__(self, fn): 91 super().__init__() 92 self.fn = fn 93 94 def forward(self, x): 95 return self.fn(x) 96 97 m = M(fn) 98 with self.assertRaisesRegexWithHighlight( 99 RuntimeError, "failed to compile", "i_dont_exist" 100 ): 101 torch.jit.script(m) 102 103 def test_init_error(self): 104 class M(nn.Module): 105 def __init__(self) -> None: 106 self.x = 2 107 108 def forward(self): 109 pass 110 111 with self.assertRaisesRegex(RuntimeError, "has not been initialized"): 112 torch.jit.script(M()) 113 114 def test_script_after_eval(self): 115 class M(nn.Module): 116 def forward(self): 117 if self.training: 118 return 2 119 else: 120 return 0 121 122 m = M() 123 sm1 = torch.jit.script(m) 124 m.eval() 125 sm2 = torch.jit.script(m) 126 127 # m is in eval mode, training should be False 128 self.assertFalse(m.training) 129 130 # sm1 was created while m had training = True 131 self.assertTrue(sm1.training) 132 self.assertEqual(sm1.training, sm1._c.getattr("training")) 133 self.assertEqual(sm1(), 2) 134 135 # sm2 was created after m was eval'ed 136 self.assertFalse(sm2.training) 137 self.assertEqual(sm2.training, sm2._c.getattr("training")) 138 self.assertEqual(sm2(), 0) 139 140 def test_module_name(self): 141 class MyModule(torch.nn.Module): 142 def __init__(self) -> None: 143 super().__init__() 144 self.x = 2 145 146 def forward(self, t): 147 return t + self.x 148 149 m = torch.jit.script(MyModule()) 150 FileCheck().check("MyModule").run(m.graph) 151 152 def test_repeated_error_stack(self): 153 def d(x): 154 return "a" - 2 155 156 def c(x): 157 return d(x) 158 159 def b(x): 160 return c(x) 161 162 def a(x): 163 return b(x) 164 165 try: 166 torch.jit.script(a) 167 except Exception as e: 168 FileCheck().check_count("is being compiled", 2).run(str(e)) 169 170 try: 171 torch.jit.script(a) 172 except Exception as e: 173 # Make sure that no entries are left over from the previous failure 174 FileCheck().check_count("is being compiled", 2).run(str(e)) 175 176 def test_constants_with_final(self): 177 class M1(torch.nn.Module): 178 x: torch.jit.Final[int] 179 180 def __init__(self) -> None: 181 super().__init__() 182 self.x = 2 183 184 def forward(self, t): 185 return t + self.x 186 187 self.checkModule(M1(), (torch.randn(2, 2),)) 188 189 class M2(torch.nn.Module): 190 x: typing_extensions.Final[int] 191 192 def __init__(self) -> None: 193 super().__init__() 194 self.x = 2 195 196 def forward(self, t): 197 return t + self.x 198 199 self.checkModule(M2(), (torch.randn(2, 2),)) 200 201 class M3(torch.nn.Module): 202 x: typing.Final[int] 203 204 def __init__(self) -> None: 205 super().__init__() 206 self.x = 2 207 208 def forward(self, t): 209 return t + self.x 210 211 self.checkModule(M3(), (torch.randn(2, 2),)) 212 213 def test_ignore_class(self): 214 @torch.jit.ignore 215 class MyScriptClass: 216 def unscriptable(self): 217 return "a" + 200 218 219 class TestModule(torch.nn.Module): 220 def forward(self, x): 221 return MyScriptClass() 222 223 with self.assertRaisesRegexWithHighlight( 224 torch.jit.frontend.FrontendError, 225 "Cannot instantiate class", 226 "MyScriptClass", 227 ): 228 t = torch.jit.script(TestModule()) 229 230 def test_method_call(self): 231 class M(nn.Module): 232 def test(self, x): 233 return x 234 235 def forward(self, z): 236 y = self.test(z) 237 return z + 20 + y 238 239 self.checkModule(M(), (torch.randn(2, 2),)) 240 241 def test_module_repr(self): 242 class Submodule(nn.Module): 243 def forward(self, x): 244 return x 245 246 class MyModule(nn.Module): 247 def __init__(self) -> None: 248 super().__init__() 249 self.conv = nn.Conv2d(10, 10, 3) 250 self.lin = nn.Linear(10, 10) 251 self.sub = Submodule() 252 253 def forward(self, x): 254 return self.lin(x) + self.sub(x) + self.conv(x) 255 256 m = torch.jit.script(MyModule()) 257 258 with self.capture_stdout() as out: 259 print(m) 260 261 f = FileCheck() 262 f.check("MyModule") 263 f.check("Conv2d") 264 f.check("Linear") 265 f.check("Submodule") 266 f.run(out[0]) 267 268 self.assertEqual(m.original_name, "MyModule") 269 270 def test_dir(self): 271 def test_module_dir(mod): 272 dir_set = dir(mod) 273 scripted_mod = torch.jit.script(mod) 274 dir_scripted = set(dir(scripted_mod)) 275 # set not currently copied over 276 ignore_set = [ 277 "training", 278 "__delitem__", 279 "__setitem__", 280 "clear", 281 "items", 282 "keys", 283 "pop", 284 "update", 285 "values", 286 ] 287 for attr in dir_set: 288 if attr in ignore_set: 289 continue 290 self.assertTrue(attr in dir_scripted, attr) 291 292 class MyModule(nn.Module): 293 def __init__(self) -> None: 294 super().__init__() 295 self.conv = nn.Conv2d(10, 10, 3) 296 self.lin = nn.Linear(10, 10) 297 298 def forward(self, x): 299 return self.lin(x) + self.conv(x) 300 301 test_module_dir(MyModule()) 302 303 # test custom __dir__ for containers 304 conv = nn.Conv2d(10, 10, 3) 305 linear = nn.Linear(10, 10) 306 307 test_module_dir(nn.Sequential(conv, linear)) 308 test_module_dir( 309 nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])) 310 ) 311 312 def test_class_compile(self): 313 def other_fn(a: int, b: Tensor) -> Tensor: 314 return a * b 315 316 class B: 317 def __init__(self, x): 318 self.x = 2 319 320 def helper(self, a): 321 return self.x + a + other_fn(self.x, a) 322 323 class N(torch.nn.Module): 324 def forward(self, x): 325 b = B(x) 326 return b.helper(x) 327 328 self.checkModule(N(), (torch.randn(2, 2),)) 329 330 def test_error_stack(self): 331 def d(x: int) -> int: 332 return x + 10 333 334 def c(x): 335 return d("hello") + d(x) 336 337 def b(x): 338 return c(x) 339 340 def a(x): 341 return b(x) 342 343 try: 344 scripted = torch.jit.script(a) 345 except RuntimeError as e: 346 checker = FileCheck() 347 checker.check("Expected a value of type 'int'") 348 checker.check("def c(x)") 349 checker.check("def b(x)") 350 checker.check("def a(x)") 351 checker.run(str(e)) 352 353 def test_error_stack_module(self): 354 def d(x: int) -> int: 355 return x + 10 356 357 def c(x): 358 return d("hello") + d(x) 359 360 def b(x): 361 return c(x) 362 363 class Submodule(torch.nn.Module): 364 def forward(self, x): 365 return b(x) 366 367 class M(torch.nn.Module): 368 def __init__(self) -> None: 369 super().__init__() 370 self.submodule = Submodule() 371 372 def some_method(self, y): 373 return y + self.submodule(y) 374 375 def forward(self, x): 376 return self.some_method(x) 377 378 try: 379 scripted = torch.jit.script(M()) 380 except RuntimeError as e: 381 checker = FileCheck() 382 checker.check("Expected a value of type 'int'") 383 checker.check("'c' is being compiled since it was called from 'b'") 384 checker.check("'b' is being compiled since it was called from") 385 checker.run(str(e)) 386 387 @_tmp_donotuse_dont_inline_everything 388 def test_script_basic(self): 389 def a_python_fn(a, b, c): 390 return a + b + c 391 392 @torch.jit.script 393 def a_script_fn(d, e, f): 394 return a_python_fn(d, e, f) 395 396 graph = str(a_script_fn.graph) 397 FileCheck().check("prim::CallFunction").run(graph) 398 FileCheck().check_not("^a_python_fn").run(graph) 399 t = torch.ones(2, 2) 400 self.assertEqual(a_script_fn(t, t, t), t + t + t) 401 402 def test_error_stack_class(self): 403 class X: 404 def bad_fn(self): 405 import pdb # noqa: F401 406 407 def fn(x) -> X: 408 return X(10) 409 410 try: 411 torch.jit.script(fn) 412 except Exception as e: 413 checker = FileCheck() 414 checker.check("import statements") 415 checker.check("is being compiled since it was called from") 416 checker.run(str(e)) 417 418 def test_error_stack_annotation(self): 419 class X: 420 def bad_fn(self): 421 import pdb # noqa: F401 422 423 def fn(x) -> X: 424 return X(10) 425 426 try: 427 torch.jit.script(fn) 428 except Exception as e: 429 checker = FileCheck() 430 checker.check("import statements") 431 checker.check("is being compiled since it was called from") 432 checker.check("-> X") 433 checker.run(str(e)) 434 435 def test_module_basic(self): 436 class Other(torch.nn.Module): 437 __constants__ = ["x"] 438 439 def __init__(self, x): 440 super().__init__() 441 self.x = x 442 self.param = torch.nn.Parameter(torch.ones(2, 2)) 443 444 def some_unscriptable_method(self): 445 a = 2 446 a = [2] 447 return a 448 449 def forward(self, t): 450 return t + self.x + self.param 451 452 class M(torch.nn.Module): 453 def __init__(self) -> None: 454 super().__init__() 455 self.other = Other(200) 456 457 def forward(self, t): 458 return self.other(t) * 2 459 460 self.checkModule(M(), (torch.ones(2, 2),)) 461 462 def test_module_function_export(self): 463 class Other(torch.nn.Module): 464 __constants__ = ["x"] 465 466 def __init__(self, x): 467 super().__init__() 468 self.x = x 469 self.param = torch.nn.Parameter(torch.ones(2, 2)) 470 471 @torch.jit.export 472 def some_entry_point(self, y): 473 return y + 20 474 475 def forward(self, t): 476 return t + self.x + self.param 477 478 class M(torch.nn.Module): 479 def __init__(self) -> None: 480 super().__init__() 481 self.other = Other(200) 482 483 def forward(self, t): 484 return self.other(t) * 2 485 486 self.checkModule(M(), (torch.ones(2, 2),)) 487 488 def test_iterable_modules(self): 489 class Inner(torch.nn.Module): 490 def forward(self, x): 491 return x + 10 492 493 class M(torch.nn.Module): 494 def __init__(self) -> None: 495 super().__init__() 496 self.sequential = nn.Sequential( 497 Inner(), Inner(), nn.Sequential(Inner(), Inner()) 498 ) 499 self.module_list = nn.ModuleList([Inner(), Inner()]) 500 501 def forward(self, x): 502 for mod in self.module_list: 503 x += mod(x) 504 x += self.sequential(x) 505 return x 506 507 self.checkModule(M(), (torch.randn(5, 5),)) 508 509 def test_prepare_scriptable_basic(self): 510 class SeluButReluWhenScripted(torch.nn.SELU): 511 def __prepare_scriptable__(self): 512 return nn.ReLU() 513 514 t = torch.randn(5, 5) 515 m = SeluButReluWhenScripted() 516 sm = torch.jit.script(m) 517 eager_out = m(t) 518 script_out = sm(t) 519 self.assertNotEqual(eager_out, script_out) 520 521 def test_prepare_scriptable_iterable_modules(self): 522 class SeluButReluWhenScripted(torch.nn.SELU): 523 def __prepare_scriptable__(self): 524 return nn.ReLU() 525 526 class M(torch.nn.Module): 527 def __init__(self) -> None: 528 super().__init__() 529 shared = SeluButReluWhenScripted() 530 self.sequential = nn.Sequential( 531 SeluButReluWhenScripted(), 532 SeluButReluWhenScripted(), 533 nn.Sequential( 534 SeluButReluWhenScripted(), shared, SeluButReluWhenScripted() 535 ), 536 shared, 537 ) 538 self.module_list = nn.ModuleList( 539 [SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()] 540 ) 541 542 def forward(self, x): 543 for mod in self.module_list: 544 x += mod(x) 545 x += self.sequential(x) 546 return x 547 548 t = torch.randn(5, 5) 549 m = M() 550 eager_out = m(t.clone()) 551 sm = torch.jit.script(m) 552 script_out = sm(t.clone()) 553 self.assertNotEqual(eager_out, script_out) 554 555 def test_prepare_scriptable_cycle(self): 556 t = torch.randn(5, 5) 557 c = torch.nn.Module() 558 p = torch.nn.Module() 559 c.__dict__["_p"] = p 560 p.__dict__["_c"] = c 561 562 sm = torch.jit.script(p) 563 564 def test_prepare_scriptable_escape_hatch(self): 565 class NonJitableClass: 566 def __call__(self, int1, int2, *args): 567 total = int1 + int2 568 for arg in args: 569 total += arg 570 return total 571 572 obj = NonJitableClass() 573 574 self.assertEqual(obj(1, 2), 3) 575 self.assertEqual(obj(1, 2, 3, 4), 10) 576 with self.assertRaisesRegex( 577 torch.jit.frontend.NotSupportedError, 578 expected_regex="can't take variable number of arguments", 579 ): 580 torch.jit.script(obj) 581 582 def escape_hatch(int1: int, int2: int) -> int: 583 return int1 + int2 584 585 class NonJitableClassWithEscapeHatch(NonJitableClass): 586 def __prepare_scriptable__(self): 587 return escape_hatch 588 589 jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch()) 590 591 self.assertEqual(jit_obj(1, 2), 3) 592 with self.assertRaisesRegex( 593 RuntimeError, 594 expected_regex=re.escape( 595 "expected at most 2 argument(s) but received 4 argument(s)" 596 ), 597 ): 598 jit_obj(1, 2, 3, 4) 599 600 def test_attributes(self): 601 @torch.jit.script 602 class Inner2: 603 def __init__(self) -> None: 604 self.b = "a string" 605 606 @torch.jit.script 607 class Foo: 608 def __init__(self) -> None: 609 self.a = 4 610 self.inner = Inner2() 611 612 @torch.jit.script 613 class SFoo: 614 def __init__(self) -> None: 615 self.a = 4 616 self.inner = Inner2() 617 618 def __setstate__(self, obj: Tuple[int, Inner2]) -> None: 619 a, inner = obj 620 self.a = a 621 self.inner = inner 622 623 def __getstate__(self): 624 return (self.a, self.inner) 625 626 untyped_values = ( 627 ("my_dict", {"I": "am", "a test": "test"}), 628 ("my_float", 2.3), 629 ("my_int", 99), 630 ("my_bool", False), 631 ("my_tuple", (1, 2, 3, 4)), 632 ("my_list", [(1, 2), (3, 4)]), 633 # ('my_tensor', torch.randn(2, 2)), 634 ("my_int_list", [1, 2, 3, 4]), 635 # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]), 636 ("my_bool_list", [True, True, False, True]), 637 ("my_float_list", [1.0, 2.0, 3.0, 4.0]), 638 ("my_str_list", ["hello", "bye"]), 639 ) 640 typed_values = ( 641 ("my_empty_list", []), 642 ("my_empty_dict", {}), 643 ("my_none", None), 644 ("my_object", Foo()), 645 ("my_object2", SFoo()), 646 ) 647 648 class M(torch.nn.Module): 649 # TODO: re-enable this once this test is in a Python 3-only syntax 650 # file 651 # my_empty_list : List[int] 652 # my_empty_dict : Dict[str, int] 653 # my_none : Optional[int] 654 655 def forward(self, x): 656 return ( 657 self.my_dict, 658 self.my_float, 659 self.my_int, 660 self.my_bool, 661 # self.my_tensor, 662 self.my_int_list, 663 # self.my_tensor_list, 664 self.my_bool_list, 665 self.my_float_list, 666 self.my_str_list, 667 self.my_empty_list, 668 self.my_empty_dict, 669 self.my_none, 670 self.my_object.a, 671 self.my_object.inner.b, 672 self.my_object.a, 673 self.my_object2.inner.b, 674 ) 675 676 # TODO: as a followup, fix this test 677 # We can't define class attributes like we should be doing: 678 # class M(torch.nn.Module): 679 # my_empty_list : List[int] 680 # my_empty_dict : Dict[str, int] 681 # my_none : Optional[int] 682 # my_out_of_line_attribute: List[int] = [1, 2, 3] 683 # since there's no string frontend for Python classes (so the `define`) 684 # trick doesn't work. 685 M.__annotations__ = { 686 "my_empty_list": List[int], 687 "my_empty_dict": Dict[str, int], 688 "my_none": Optional[int], 689 "my_object": Foo, 690 "my_object2": SFoo, 691 } 692 693 m = M() 694 for name, value in untyped_values + typed_values: 695 setattr(m, name, value) 696 697 self.checkModule(m, (torch.randn(5, 5),)) 698 699 def test_function_attribute_in_submodule(self): 700 class N(nn.Module): 701 def __init__(self, norm): 702 super().__init__() 703 self.activation = torch.nn.functional.relu 704 self.norm = norm 705 706 def forward(self, src): 707 output = src 708 output = self.norm(output) 709 return output 710 711 class M(nn.Module): 712 def __init__(self) -> None: 713 super().__init__() 714 encoder_norm = nn.ReLU() 715 self.encoder = N(encoder_norm) 716 717 def forward(self, x): 718 return self.encoder(x) 719 720 m = M() 721 self.checkModule(m, (torch.randn(5, 5),)) 722 723 def test_inner_traced_module(self): 724 class Dummy(nn.Module): 725 def forward(self, x): 726 return x 727 728 class Model(nn.Module): 729 def __init__(self, dummies): 730 super().__init__() 731 self._dummies = dummies 732 733 def forward(self, x): 734 out = [] 735 for dummy in self._dummies: 736 out.append(dummy(x)) 737 return out 738 739 dummy = torch.jit.trace(Dummy(), torch.randn(1, 2)) 740 dummies = nn.ModuleList([dummy]) 741 model = Model(dummies) 742 self.checkModule(model, (torch.rand(5, 5),)) 743 744 def test_script_loaded_module(self): 745 """ 746 Test that we can hold a loaded ScriptModule as a submodule. 747 """ 748 749 class Dummy(nn.Module): 750 def forward(self, x): 751 return x 752 753 dummy = torch.jit.script(Dummy()) 754 dummy = self.getExportImportCopy(dummy) 755 756 class ContainsLoaded(torch.nn.Module): 757 def __init__(self) -> None: 758 super().__init__() 759 self.encoder = dummy 760 761 def forward(self, input): 762 return self.encoder(input) 763 764 self.checkModule(ContainsLoaded(), (torch.rand(2, 3),)) 765 766 def test_optional_module(self): 767 class Dummy(nn.Module): 768 def __init__(self) -> None: 769 super().__init__() 770 self.foo = nn.Linear(2, 2) 771 772 def forward(self, x): 773 if self.foo is not None: 774 return self.foo(x) 775 return x 776 777 mod = Dummy() 778 self.checkModule(mod, (torch.rand(2, 2),)) 779 mod.foo = None 780 self.checkModule(mod, (torch.rand(2, 2),)) 781 782 def test_override_instance_method_ignore(self): 783 class M(torch.nn.Module): 784 @torch.jit.ignore 785 def i_am_ignored(self): 786 return "old" 787 788 m = M() 789 790 # Override the ignored method by binding a new method to this instance. 791 @torch.jit.ignore 792 def i_am_ignored(self): 793 return "new" 794 795 m.i_am_ignored = types.MethodType(i_am_ignored, m) 796 self.assertEqual(m.i_am_ignored(), "new") 797 798 # ScriptModule should correctly reflect the override. 799 s = torch.jit.script(m) 800 self.assertEqual(s.i_am_ignored(), "new") 801