# Owner(s): ["oncall: jit"]

import os
import sys
from typing import Any, List

import torch
import torch.nn as nn
from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase, make_global


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


# Initial implementation installed as `proxy_mod` by most tests below.
# forward(x) == x + (x + x + 1) + 1 == 3 * x + 2.
# `two` is used by test_not_submodule_interface_call to call a method that the
# declared interface does not expose.
class OrigModule(nn.Module):
    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        return inp1 + inp2 + 1

    def two(self, input: Tensor) -> Tensor:
        return input + 2

    def forward(self, input: Tensor) -> Tensor:
        return input + self.one(input, input) + 1


# Swap-in replacement for OrigModule that provides the same `one`/`forward`
# methods but computes forward(x) == x * (x + 1) + 1.
class NewModule(nn.Module):
    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        return inp1 * inp2 + 1

    def forward(self, input: Tensor) -> Tensor:
        return self.one(input, input + 1)


class TestModuleInterface(JitTestCase):
    """Tests for ``torch.jit.interface``.

    Covers calling methods through module/class interfaces, interface
    subtyping rules, swapping interface-typed submodules on scripted modules,
    and freezing modules that hold interface-typed attributes.
    """

    def test_not_submodule_interface_call(self):
        """Calling a method missing from the interface type must not compile.

        OrigModule defines ``two``, but ``proxy_mod`` is typed as an interface
        that only declares ``one``, so ``self.proxy_mod.two`` is an error.
        """

        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

        class TestNotModuleInterfaceCall(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.two(input)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", "self.proxy_mod.two"
        ):
            torch.jit.script(TestNotModuleInterfaceCall())

    def test_module_interface(self):
        """Script functions can accept lists of module and class interfaces.

        Methods declared on the interface can be called through it; calling a
        method outside the interface (``forward2``) is a compile error.
        """

        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

            def forward(self, x: Tensor) -> Tensor:
                pass

        # Class (non-nn.Module) interface; the scripted modules below are also
        # passed where a list of this interface is expected.
        @torch.jit.interface
        class OneTwoClass:
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x + y

            def two(self, x: Tensor) -> Tensor:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return self.one(self.two(x), x)

        # Also exports `forward2`, which is NOT part of OneTwoModule.
        class BarMod(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x * y

            def two(self, x: Tensor) -> Tensor:
                return 2 / x

            def forward(self, x: Tensor) -> Tensor:
                return self.two(self.one(x, x))

            @torch.jit.export
            def forward2(self, x: Tensor) -> Tensor:
                return self.two(self.one(x, x)) + 1

        # The interfaces must be resolvable by name from the scripted functions.
        make_global(OneTwoModule, OneTwoClass)

        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
            return mod_list[0].two(x) + mod_list[1].one(x, x)

        scripted_foo_mod = torch.jit.script(FooMod())
        scripted_bar_mod = torch.jit.script(BarMod())
        self.checkScript(
            use_module_interface,
            (
                [scripted_foo_mod, scripted_bar_mod],
                torch.rand(3, 4),
            ),
        )
        self.checkScript(
            use_class_interface,
            (
                [scripted_foo_mod, scripted_bar_mod],
                torch.rand(3, 4),
            ),
        )

        def call_module_interface_on_other_method(
            mod_interface: OneTwoModule, x: Tensor
        ) -> Tensor:
            return mod_interface.forward2(x)

        # Calling a method that the interface does not declare (forward2) must
        # fail, even though the concrete BarMod defines and exports it.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", "mod_interface.forward2"
        ):
            self.checkScript(
                call_module_interface_on_other_method,
                (
                    scripted_bar_mod,
                    torch.rand(3, 4),
                ),
            )

    def test_module_doc_string(self):
        """String statements inside interface method bodies must be accepted.

        ``TestInterface.forward`` places several raw-string statements around
        ``pass`` after a ``# type:`` comment; the interface must still compile
        and the module using it must match eager execution.
        """

        @torch.jit.interface
        class TestInterface(nn.Module):
            def one(self, inp1, inp2):
                # type: (Tensor, Tensor) -> Tensor
                pass

            def forward(self, input):
                # type: (Tensor) -> Tensor
                r"""stuff 1"""
                r"""stuff 2"""
                pass  # noqa: PIE790
                r"""stuff 3"""

        class TestModule(nn.Module):
            proxy_mod: TestInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input):
                # type: (Tensor) -> Tensor
                return self.proxy_mod.forward(input)

        input = torch.randn(3, 4)
        self.checkModule(TestModule(), (input,))

    def test_module_interface_subtype(self):
        """Subtyping rules for module interfaces.

        A TorchScript class object is not accepted as a module interface, a
        module with incompatible methods is rejected, and implementations may
        be contravariant in argument types and covariant in return type.
        """

        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                pass

            def two(self, x: Tensor) -> Tensor:
                pass

            def forward(self, x: Tensor) -> Tensor:
                pass

        make_global(OneTwoModule)

        @torch.jit.script
        def as_module_interface(x: OneTwoModule) -> OneTwoModule:
            return x

        @torch.jit.script
        class Foo:
            def one(self, x: Tensor, y: Tensor) -> Tensor:
                return x + y

            def two(self, x: Tensor) -> Tensor:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return self.one(self.two(x), x)

        # A TorchScript class object is not a subtype of a module interface,
        # even though Foo has matching method signatures.
        with self.assertRaisesRegex(
            RuntimeError, "ScriptModule class can be subtype of module interface"
        ):
            as_module_interface(Foo())

        # Incompatible with OneTwoModule: `one` is missing and `two` takes and
        # returns int instead of Tensor.
        class WrongMod(nn.Module):
            def two(self, x: int) -> int:
                return 2 * x

            def forward(self, x: Tensor) -> Tensor:
                return x + torch.randn(3, self.two(3))

        scripted_wrong_mod = torch.jit.script(WrongMod())

        # A module that is not compatible with the module interface is rejected.
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            as_module_interface(scripted_wrong_mod)

        # Check that interface implementations can be contravariant in argument types and covariant in return type.
        @torch.jit.interface
        class TensorToAny(nn.Module):
            def forward(self, input: torch.Tensor) -> Any:
                pass

        make_global(TensorToAny)

        @torch.jit.script
        def as_tensor_to_any(x: TensorToAny) -> TensorToAny:
            return x

        @torch.jit.interface
        class AnyToAny(nn.Module):
            def forward(self, input: Any) -> Any:
                pass

        make_global(AnyToAny)

        @torch.jit.script
        def as_any_to_any(x: AnyToAny) -> AnyToAny:
            return x

        # Accepts a wider argument type (Any instead of Tensor).
        class TensorToAnyImplA(nn.Module):
            def forward(self, input: Any) -> Any:
                return input

        # Wider argument type and narrower return type (Tensor instead of Any).
        class TensorToAnyImplB(nn.Module):
            def forward(self, input: Any) -> torch.Tensor:
                return torch.tensor([1])

        # Narrower return type (Tensor instead of Any).
        class AnyToAnyImpl(nn.Module):
            def forward(self, input: Any) -> torch.Tensor:
                return torch.tensor([1])

        # Each call must succeed (no exception) for the variance rules to hold.
        as_tensor_to_any(torch.jit.script(TensorToAnyImplA()))
        as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
        as_any_to_any(torch.jit.script(AnyToAnyImpl()))

    def test_module_interface_inheritance(self):
        """Declaring an interface that inherits from a concrete module must raise."""
        with self.assertRaisesRegex(
            RuntimeError, "does not support inheritance yet. Please directly"
        ):

            @torch.jit.interface
            class InheritMod(nn.ReLU):
                def three(self, x: Tensor) -> Tensor:
                    return 3 * x

    def test_module_swap(self):
        """Interface-typed submodules can be swapped after scripting.

        Swapping in a compatible ScriptModule changes the result; assigning an
        unscripted nn.Module must raise.
        """

        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        scripted_mod = torch.jit.script(TestModule())
        input = torch.randn(3, 4)
        # OrigModule.forward(x) == 3 * x + 2
        self.assertEqual(scripted_mod(input), 3 * input + 2)

        # Swap in a scripted module that satisfies the same interface; the
        # result now follows NewModule.forward(x) == x * (x + 1) + 1.
        scripted_mod.proxy_mod = torch.jit.script(NewModule())
        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)

        # Swapping in a non-scripted module must throw.
        with self.assertRaisesRegex(
            RuntimeError, "a ScriptModule with non-scripted module"
        ):
            scripted_mod.proxy_mod = NewModule()

    def test_module_swap_wrong_module(self):
        """Swapping in a module whose ``forward`` signature (int -> int)
        differs from the interface (Tensor -> Tensor) must be rejected."""

        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class NewModuleWrong(nn.Module):
            def forward(self, input: int) -> int:
                return input + 1

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        scripted_mod = torch.jit.script(TestModule())
        # Swapping in a module with an incompatible interface must fail.
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            scripted_mod.proxy_mod = torch.jit.script(NewModuleWrong())

    def test_module_swap_no_lazy_compile(self):
        """Every interface method must be compiled on the swapped-in module.

        ``one`` is only compiled when reachable from ``forward`` or explicitly
        exported with ``@torch.jit.export``; otherwise the swap is rejected.
        """

        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        # `forward` never calls `one`, so `one` is not compiled by scripting.
        class NewModuleMethodNotLazyCompile(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            def forward(self, input: Tensor) -> Tensor:
                return input + 1

        scripted_mod = torch.jit.script(TestModule())
        # The module defines every interface method, but `one` is not reachable
        # from `forward` and so is never lazily compiled; the user must export it
        # explicitly for the swap to work.
        with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
            scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodNotLazyCompile())

        class NewModuleMethodManualExport(nn.Module):
            @torch.jit.export
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            def forward(self, input: Tensor) -> Tensor:
                return input + 1

        # With `one` exported, the swap succeeds and forward(x) == x + 1.
        scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
        input = torch.randn(3, 4)
        self.assertEqual(scripted_mod(input), input + 1)

    def test_module_swap_no_module_interface(self):
        """Without an interface annotation, ``proxy_mod`` has OrigModule's
        concrete JIT type, so only a module of that same type can be swapped in."""
        # test module swapping with no module interface
        class TestNoModuleInterface(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod(input)

        scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
        # Swapping in a fresh ScriptModule with the same JIT type (OrigModule) succeeds.
        scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
        # proxy_mod is not interface-typed and NewModule has a different JIT
        # type, so this swap must fail. NOTE(review): the regex hard-codes the
        # qualified names `__torch__.jit.test_module_interface.*`, so it depends
        # on the module path this file is imported under.
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' "
            + r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'",
        ):
            scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())

    def test_script_module_as_interface_swap(self):
        """Legacy ``torch.jit.ScriptModule`` subclasses (with ``script_method``)
        satisfy module interfaces and can be swapped in at runtime."""

        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

            def forward(self, input: Tensor) -> Tensor:
                pass

        # Same arithmetic as OrigModule: forward(x) == 3 * x + 2.
        class OrigScriptModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 + inp2 + 1

            @torch.jit.script_method
            def forward(self, input: Tensor) -> Tensor:
                return input + self.one(input, input) + 1

        # Same arithmetic as NewModule: forward(x) == x * (x + 1) + 1.
        class NewScriptModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                return inp1 * inp2 + 1

            @torch.jit.script_method
            def forward(self, input: Tensor) -> Tensor:
                return self.one(input, input + 1)

        class TestNNModuleWithScriptModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigScriptModule()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.forward(input)

        input = torch.randn(3, 4)
        scripted_mod = torch.jit.script(TestNNModuleWithScriptModule())
        self.assertEqual(scripted_mod(input), 3 * input + 2)

        scripted_mod.proxy_mod = NewScriptModule()
        self.assertEqual(scripted_mod(input), input * (input + 1) + 1)

    # The call to proxy_mod.forward goes through an interface type, so it cannot
    # be inlined. Freezing must nevertheless succeed (with default options and
    # with freezeInterfaces=True), and the module frozen with
    # freezeInterfaces=True must produce the same output as the scripted module.
    def test_freeze_module_with_interface(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 20

            def forward(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> int:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()  # folded

            def forward(self, x):
                return self.proxy_mod(x) + self.sub(x)

        m = torch.jit.script(TestModule())
        m.eval()
        # Freezing with default options must not raise; the result is
        # overwritten by the next call.
        mf = torch._C._freeze_module(m._c)
        # Assume interface has no aliasing
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
        input = torch.tensor([1])
        out_s = m.forward(input)
        out_f = mf.forward(input)
        self.assertEqual(out_s, out_f)

    def test_freeze_module_with_setattr_in_interface(self):
        """Freezing with freezeInterfaces=True must not raise when the module
        bound to the interface attribute mutates its own attribute in
        ``forward`` (``self.b += 2``). Only the absence of an exception is
        checked."""

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 20

            def forward(self, x):
                self.b += 2
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> int:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                return self.proxy_mod(x) + self.sub.getb(x)

        m = torch.jit.script(TestModule())
        # Bind the interface attribute to the mutating submodule.
        m.proxy_mod = m.sub
        m.eval()
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_freeze_module_with_inplace_mutation_in_interface(self):
        """Freezing with freezeInterfaces=True must not raise when the module
        behind the interface mutates a tensor attribute in place
        (``self.b[0] += 2``). Only the absence of an exception is checked."""

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                self.b[0] += 2
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                y = self.proxy_mod(x)
                z = self.sub.getb(x)
                return y[0] + z[0]

        m = torch.jit.script(TestModule())
        # Bind the interface attribute to the mutating submodule.
        m.proxy_mod = m.sub
        # Presumably ensures the tensor read via getb and the one mutated via
        # the interface call are the same object (aliasing) — NOTE(review):
        # confirm intent.
        m.sub.b = m.proxy_mod.b
        m.eval()
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_freeze_module_with_mutated_interface(self):
        """Reassigning an interface-typed attribute inside ``forward`` (a
        SetAttr on ``proxy_mod``) is unsupported by freezing and must raise."""

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                return self.b

            @torch.jit.export
            def getb(self, x):
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                self.proxy_mod = self.sub
                y = self.proxy_mod(x)
                z = self.sub.getb(x)
                return y[0] + z[0]

        m = torch.jit.script(TestModule())
        m.eval()
        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type."
        ):
            mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_freeze_module_with_interface_and_fork(self):
        """Freezing with freezeInterfaces=True must not raise when a module that
        holds an interface attribute is invoked both directly and through
        ``torch.jit._fork``. Only the absence of an exception is checked."""

        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.tensor([1.5])

            def forward(self, x):
                self.b[0] += 3.2
                return self.b

        class OrigMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([0.5])

            def forward(self, x):
                return self.a

        @torch.jit.interface
        class ModInterface(torch.nn.Module):
            def forward(self, x: Tensor) -> Tensor:
                pass

        class TestModule(torch.nn.Module):
            proxy_mod: ModInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigMod()
                self.sub = SubModule()

            def forward(self, x):
                y = self.proxy_mod(x)
                z = self.sub(x)
                return y + z

        # Calls TestModule.forward once via _fork/_wait and once directly.
        class MainModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.test = TestModule()

            def forward(self, x):
                fut = torch.jit._fork(self.test.forward, x)
                y = self.test(x)
                z = torch.jit._wait(fut)
                return y + z

        m = torch.jit.script(MainModule())
        m.eval()
        mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    def test_module_apis_interface(self):
        """Scripting must fail with "Could not compile" when an exported method
        iterates ``self.modules()`` and calls each module while ``proxy_mod``
        is interface-typed."""

        @torch.jit.interface
        class ModuleInterface(nn.Module):
            def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
                pass

        class TestModule(nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = OrigModule()

            def forward(self, input):
                return input * 2

            @torch.jit.export
            def method(self, input):
                for module in self.modules():
                    input = module(input)
                return input

        with self.assertRaisesRegex(Exception, "Could not compile"):
            scripted_mod = torch.jit.script(TestModule())