1# Owner(s): ["oncall: jit"] 2 3import io 4import os 5import sys 6import unittest 7 8import torch 9import torch._C 10from torch.jit.mobile import _load_for_lite_interpreter 11from torch.testing import FileCheck 12from torch.testing._internal.common_utils import ( 13 find_library_location, 14 IS_FBCODE, 15 IS_MACOS, 16 IS_SANDCASTLE, 17 IS_WINDOWS, 18 skipIfRocm, 19 TEST_WITH_ROCM, 20) 21from torch.testing._internal.jit_utils import JitTestCase 22 23 24# Make the helper files in test/ importable 25pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 26sys.path.append(pytorch_test_dir) 27 28if __name__ == "__main__": 29 raise RuntimeError( 30 "This test file is not meant to be run directly, use:\n\n" 31 "\tpython test/test_jit.py TESTNAME\n\n" 32 "instead." 33 ) 34 35 36def to_test_backend(module, method_compile_spec): 37 return torch._C._jit_to_backend( 38 "test_backend", module, {"forward": method_compile_spec} 39 ) 40 41 42def to_test_backend_multi(module, method_compile_spec): 43 return torch._C._jit_to_backend("test_backend", module, method_compile_spec) 44 45 46def to_test_backend_selective(module, method_compile_spec, submodules): 47 def _to_test_backend(module): 48 return to_test_backend(module, method_compile_spec) 49 50 return torch._C._jit_to_backend_selective(module, _to_test_backend, submodules) 51 52 53class BasicModule(torch.nn.Module): 54 """ 55 A simple Module used to test to_backend lowering machinery. 56 """ 57 58 def forward(self, x, h): 59 return self.accum(x, h), self.sub_accum(x, h) 60 61 def accum(self, x, h): 62 return x + h 63 64 def sub_accum(self, x, h): 65 return x - h 66 67 68# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends. 
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class JitBackendTestCase(JitTestCase):
    """
    A common base class for JIT backend tests that contains common utility
    functions for output comparison and serialization/deserialization.
    """

    def setUp(self):
        super().setUp()
        lib_file_path = find_library_location("libjitbackend_test.so")
        torch.ops.load_library(str(lib_file_path))
        # Subclasses are expected to set up three variables in their setUp methods:
        # module - a regular, Python version of the module being tested
        # scripted_module - a scripted version of module
        # lowered_module - a version of module lowered to a backend

    def check_function(self, function_name, input):
        """
        Check that the function named 'function_name' produces the same output using
        Python, regular JIT and the backend for the given 'input'.

        Args:
            function_name: name of the method to look up on self.module,
                self.scripted_module and self.lowered_module.
            input: tuple of positional arguments passed to each method.
        """
        # Get handles for Python, JIT and backend methods.
        python_method = self.module.__getattribute__(function_name)
        jit_method = self.scripted_module.__getattr__(function_name)
        backend_method = self.lowered_module.__getattr__(function_name)

        # Run methods.
        python_output = python_method(*input)
        jit_output = jit_method(*input)
        backend_output = backend_method(*input)

        # The answers returned by Python, JIT and to_backend should all match.
        self.assertEqual(python_output, backend_output)
        self.assertEqual(jit_output, backend_output)

    def save_load(self):
        """
        Save and load the lowered module, replacing self.lowered_module with the
        reloaded copy.
        """
        self.lowered_module = self.getExportImportCopy(self.lowered_module)

    def test_execution(self):
        """
        Stub for correctness tests.
        """

    def test_save_load(self):
        """
        Stub for serialization tests.
        """

    def test_errors(self):
        """
        Stub for testing error checking.
        """


class BasicModuleTest(JitBackendTestCase):
    """
    Tests for BasicModule.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input = torch.randn(5)

        # Test all three module methods.
        self.check_function("accum", (input, input))
        self.check_function("sub_accum", (input, input))
        self.check_function("forward", (input, input))

    @skipIfRocm
    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save the compile spec to compare against the version retrieved after loading.
        pre_compile_spec = self.lowered_module.__getattr__(
            "__loweredModule__"
        ).__getattr__("__method_compile_spec")

        # Save and load the lowered module.
        self.save_load()

        # Get the compile spec after loading.
        post_compile_spec = self.lowered_module.__getattr__(
            "__loweredModule__"
        ).__getattr__("__method_compile_spec")

        # Compile specs should match.
        self.assertEqual(pre_compile_spec, post_compile_spec)

        # Loaded module should produce the same outputs.
        self.test_execution()


class BasicModuleUnavailableTest(JitBackendTestCase):
    """
    Tests for BasicModule with a backend that is not available.
    Expected behavior:
        * _jit_to_backend is successful.
        * Execution fails with an exception.
        * Saving is successful.
        * Loading fails with an exception.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = torch._C._jit_to_backend(
            "test_backend_unavailable",
            self.scripted_module,
            {"forward": {"": ""}},
        )

    def test_execution(self):
        # Executing the lowered module fails because its backend is not available.
        input = torch.randn(5)

        # Test exception is thrown.
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            backend_method = self.lowered_module.__getattr__("forward")
            backend_method(input, input)

    @skipIfRocm
    def test_save_load(self):
        # Saving the lowered module is OK, but loading it fails because the
        # backend is not available.
        buffer = io.BytesIO()
        torch.jit.save(self.lowered_module, buffer)
        buffer.seek(0)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            torch.jit.load(buffer)


class NestedModuleTest(JitBackendTestCase):
    """
    Tests for NestedModule that check that a module lowered to a backend can be used
    as a submodule.
    """

    class NestedModule(torch.nn.Module):
        """
        A Module with one submodule that is used to test that lowered Modules
        can be used as submodules.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, h):
            return self.submodule.forward(x, h)

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of NestedModule.
        # Both modules in self.module are regular Python modules.
        self.module = NestedModuleTest.NestedModule(BasicModule())
        # Both modules in self.scripted_module are ScriptModules.
        self.scripted_module = torch.jit.script(
            NestedModuleTest.NestedModule(BasicModule())
        )

        # Lower a freshly scripted BasicModule (not the one inside
        # self.scripted_module) so that self.scripted_module is left unchanged.
        lowered_module = to_test_backend_multi(
            torch.jit.script(BasicModule()),
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )
        # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
        self.lowered_module = torch.jit.script(
            NestedModuleTest.NestedModule(lowered_module)
        )

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input = torch.randn(5)

        # Test forward.
        self.check_function("forward", (input, input))

    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save and load the lowered module.
        self.save_load()

        # Loaded module should produce the same outputs.
        self.test_execution()


class SelectiveLoweringTest(JitBackendTestCase):
    """
    Tests for the selective lowering API.
    """

    class OuterModule(torch.nn.Module):
        def __init__(self, sub1, sub2, other):
            super().__init__()
            self.sub1 = sub1
            self.sub2 = sub2
            self.other = other

        def forward(self, x, y):
            # Call the module that will be lowered directly to test
            # type remapping in modules that are not its parent.
            a, b = self.sub1.submodule.forward(x, y)
            c, d = self.sub2.forward(x, y)
            e, f = self.other.forward(x, y)
            return a + c + e, b + d + f

    class MiddleModule(torch.nn.Module):
        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, y):
            return self.submodule.forward(x, y)

    def setUp(self):
        super().setUp()
        OuterModule = SelectiveLoweringTest.OuterModule
        MiddleModule = SelectiveLoweringTest.MiddleModule

        def script_without_type_sharing(mod):
            return torch.jit._recursive.create_script_module(
                mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
            )

        # Create Python, JIT and backend versions of a hierarchy that looks like this:
        #                 --------- OuterModule --------
        #                 |              |              |
        #           MiddleModule    MiddleModule    MiddleModule
        #                |               |              |
        #           BasicModule     BasicModule     BasicModule
        #
        # Two BasicModules will be lowered and the third will not.
        self.module = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )
        self.scripted_module = script_without_type_sharing(
            OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )
        )
        self.lowered_module = script_without_type_sharing(
            OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )
        )
        self.lowered_module = to_test_backend_selective(
            self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
        )

    def test_execution(self):
        input = torch.randn(5)
        self.check_function("forward", (input, input))

        self.test_selective_lowering_type_remap()

    def test_save_load(self):
        self.test_execution()
        self.save_load()
        self.test_execution()

        self.test_selective_lowering_type_remap()

    def test_selective_lowering_type_remap(self):
        """
        Check that type remapping and replacement occurred during selective lowering.
        """
        # self.scripted_module still refers to the original BasicModule types.
        # self.lowered_module itself was not lowered (its graph does not mention the
        # test_backend TorchBind class), but its graph refers to the LoweredWrapper
        # type because OuterModule.forward calls sub1.submodule directly.
        FileCheck().check("OuterModule").check("BasicModule").run(
            self.scripted_module.graph
        )
        FileCheck().check("OuterModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)

        # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub1.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)

        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub2.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)

        # Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
        # __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
        # the TorchBind class for executing functions on the test JIT backend.
        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)

        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)

        # Check that other and other.submodule have been left untouched by the selective lowering process,
        # both in the never-lowered reference module and in the selectively lowered module.
        for root in (self.scripted_module, self.lowered_module):
            FileCheck().check("MiddleModule").check("BasicModule").check_not(
                "__torch__.torch.classes.__backends__.test_backend"
            ).check_not("LoweredWrapper.test_backend").run(root.other.graph)
            FileCheck().check("BasicModule").check_not(
                "__torch__.torch.classes.__backends__.test_backend"
            ).check_not("LoweredModule.test_backend").run(root.other.submodule.graph)

    def test_errors(self):
        """
        Check errors associated with selective lowering.
        """
        # Check error messages thrown when attempting to lower something that is not a ScriptModule.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Object .* is not a ScriptModule", ""
        ):
            to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])

        MiddleModule = SelectiveLoweringTest.MiddleModule
        mod = MiddleModule(BasicModule())
        mod.new_attr = 3

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute named new_attr is not a Module", ""
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["new_attr"]
            )

        # Check error message thrown when module hierarchy doesn't have unique types.
        OuterModule = SelectiveLoweringTest.OuterModule
        mod = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"Selective lowering is only supported for module hierarchies with unique types",
            "",
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
            )


# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackends(JitTestCase):
    """
    This class wraps and invokes all subclasses of JitBackendTestCase so that each one
    does not have to be individually imported in test_jit.py.
    """

    def __init__(self, name):
        super().__init__(name)
        self.basic_module_test = BasicModuleTest(name)
        self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
        self.nested_module_test = NestedModuleTest(name)
        self.selective_lowering_test = SelectiveLoweringTest(name)

    def setUp(self):
        super().setUp()
        if not TEST_WITH_ROCM:
            self.basic_module_test.setUp()
            self.basic_module_unavailable_test.setUp()
            self.nested_module_test.setUp()
            self.selective_lowering_test.setUp()

    @skipIfRocm
    def test_execution(self):
        self.basic_module_test.test_execution()
        self.basic_module_unavailable_test.test_execution()
        self.nested_module_test.test_execution()
        self.selective_lowering_test.test_execution()

    @skipIfRocm
    def test_save_load(self):
        self.basic_module_test.test_save_load()
        self.basic_module_unavailable_test.test_save_load()
        self.nested_module_test.test_save_load()
        self.selective_lowering_test.test_save_load()

    @skipIfRocm
    def test_errors(self):
        self.selective_lowering_test.test_errors()
"""
Unit Tests for backend with compiler
This test case and the existing TestBackends are separate because they cover different aspects.
The actual backend implementation in this test is different.
It has a simple demo compiler to test the end-to-end flow in mobile.
However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
"""


class BasicModuleAdd(torch.nn.Module):
    """
    A simple add Module used to test to_backend lowering machinery.
    """

    def forward(self, x, h):
        return x + h


# A class-level skipIf only applies when a runner executes this class directly;
# TestBackendsWithCompiler drives these classes manually and has its own skipIf.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class JitBackendTestCaseWithCompiler(JitTestCase):
    """
    A common base class for JIT backend tests with compilers that contains common utility
    functions for output comparison.
    """

    def setUp(self):
        super().setUp()
        lib_file_path = find_library_location("libbackend_with_compiler.so")
        torch.ops.load_library(str(lib_file_path))
        # Subclasses are expected to set up four variables in their setUp methods:
        # module - a regular, Python version of the module being tested
        # scripted_module - a scripted version of module
        # lowered_module - a version of module lowered to a backend
        # mobile_module - a module with a format that Pytorch Mobile can execute

    def check_forward(self, input):
        """
        Check that the forward function produces the same output using
        Python, regular JIT, the backend, and mobile for the given 'input'.

        Args:
            input: tuple of positional arguments passed to each module.
        """

        # Get outputs from forward.
        python_output = self.module.forward(*input)
        jit_output = self.scripted_module.forward(*input)
        backend_output = self.lowered_module(*input)
        mobile_output = self.mobile_module(*input)

        # The answers returned by Python, JIT, to_backend, and mobile should all match.
        self.assertEqual(python_output, backend_output)
        self.assertEqual(jit_output, backend_output)
        self.assertEqual(mobile_output, backend_output)

    def test_execution(self):
        """
        Stub for correctness tests.
        """

    def test_errors(self):
        """
        Stub for testing error checking.
        """


class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
    """
    Tests for BasicModuleAdd.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModuleAdd.
        self.module = BasicModuleAdd()
        self.scripted_module = torch.jit.script(BasicModuleAdd())
        compile_spec = {
            "forward": {
                "input_shapes": "((1, 1, 320, 240), (1, 3))",
                "some_other_option": "True",
            },
        }
        self.lowered_module = torch._C._jit_to_backend(
            "backend_with_compiler_demo", self.scripted_module, compile_spec
        )
        # Create mobile version of BasicModuleAdd
        buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input = torch.ones(1, dtype=torch.float)
        self.check_forward((input, input))


class ErrorMessagesWithCompiler(JitBackendTestCase):
    """
    Tests for errors that occur with compiler, specifically:
        * an operator is not supported by the backend
    """

    class ModuleNotSupported(torch.nn.Module):
        """
        A module with an operator that is not supported.

        NOTE: test_errors matches the source text of ``forward`` (its indentation
        and the unreachable line after ``return``) inside the expected error
        message, so ``forward`` must stay in sync with that regex.
        """

        def forward(self, x, h):
            return x * h
            self._loweredmodule.forward()

    def test_errors(self):
        scripted_module_n = torch.jit.script(
            ErrorMessagesWithCompiler.ModuleNotSupported()
        )
        # Test exception is thrown when lowering a module with an unsupported operator
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            # Special escape characters are replaced with '.'
            r"""The node of aten::mul is not supported in this compiler. .*
        def forward.self, x, h.:
            return x . h
                   ~~~~~ <--- HERE
            self._loweredmodule.forward..
""",
            "",
        ):
            torch._C._jit_to_backend(
                "backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
            )


class CompModuleTestWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, which is a module with two lowered submodules
    """

    class BasicModuleSub(torch.nn.Module):
        """
        A simple subtraction Module to be used in CompModule.
        """

        def forward(self, x, h):
            return x - h

    class CompModule(torch.nn.Module):
        """
        A module with two lowered submodules.
        """

        def __init__(self, addmodule, submodule):
            super().__init__()
            self.lowered_add = addmodule
            self.lowered_sub = submodule

        def forward(self, a, b, s):
            c = self.lowered_add.forward(a, b)
            d = self.lowered_sub.forward(a, b)
            y = s * (c * d)
            return y

    def setUp(self):
        super().setUp()
        # Create Python and JIT versions of CompModule with lowered submodules.
        compile_spec = {
            "forward": {
                "input_shapes": "((1, 1, 320, 240), (1, 3))",
                "some_other_option": "True",
            },
        }
        lowered_add = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            torch.jit.script(BasicModuleAdd()),
            compile_spec,
        )
        lowered_sub = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
            {"forward": {"": ""}},
        )
        self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
        self.scripted_module = torch.jit.script(
            CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
        )
        # No backend version of CompModule currently, so this is filler.
        self.lowered_module = self.scripted_module
        # Create a mobile version of CompModule from JIT version
        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input1 = torch.ones(1, dtype=torch.float)
        input2 = torch.ones(1, dtype=torch.float)

        # Test forward.
        self.check_function("forward", (input1, input2, input2))


# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(
    IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackendsWithCompiler(JitTestCase):
    """
    This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
    so that each one does not have to be individually imported in test_jit.py.
    """

    def __init__(self, name):
        super().__init__(name)
        self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
        self.error_module_compiler_test = ErrorMessagesWithCompiler(name)
        self.comp_module_compiler_test = CompModuleTestWithCompiler(name)

    def setUp(self):
        super().setUp()
        self.basic_module_compiler_test.setUp()
        self.error_module_compiler_test.setUp()
        self.comp_module_compiler_test.setUp()

    def test_execution(self):
        self.basic_module_compiler_test.test_execution()
        self.comp_module_compiler_test.test_execution()

    def test_errors(self):
        self.error_module_compiler_test.test_errors()


class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, which is a module with two lowered submodules with same module name
    """

    class ModuleAdd(torch.nn.Module):
        """
        A simple Module used to test to_backend lowering machinery.
        """

        def forward(self, x, h):
            return x + h

    class CompModule(torch.nn.Module):
        """
        A module with two lowered submodules, both lowered from ModuleAdd.
        """

        def __init__(self) -> None:
            super().__init__()
            compile_spec = {
                "forward": {
                    "some_other_option": "True",
                },
            }
            # A class body is not an enclosing scope for functions nested in it, so the
            # sibling nested class must be reached through the outer test class.
            module_add = CompModuleTestSameNameWithCompiler.ModuleAdd
            self.add = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(module_add()),
                compile_spec,
            )
            self.sub = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(module_add()),
                compile_spec,
            )

        def forward(self, a, b, s: int):
            c = self.add.forward(a, b)
            d = self.sub.forward(a, b)
            y = s * (c * d)
            return y

    def setUp(self):
        super().setUp()

        self.module = CompModuleTestSameNameWithCompiler.CompModule()
        self.scripted_module = torch.jit.script(self.module)
        # No backend version of CompModule itself; check_function still needs a
        # lowered_module, so reuse the scripted module as filler.
        self.lowered_module = self.scripted_module
        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        a = torch.ones(1)
        b = 3 * torch.ones(1)
        s = 3
        # Test forward.
        self.check_function("forward", (a, b, s))


class AddedAttributesTest(JitBackendTestCase):
    """
    Tests for adding attributes to a model after lowering.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def test_attribute(self):
        from torch.utils import bundled_inputs

        # BasicModule.forward takes two tensors (x, h).
        input = [(torch.ones(5), torch.ones(5))]
        pre_bundled = self.lowered_module(*input[0])
        # Attach bundled inputs which adds several attributes and functions to the model.
        # augment_model_with_bundled_inputs modifies the model in place and returns None,
        # so its result must not be assigned back to self.lowered_module.
        bundled_inputs.augment_model_with_bundled_inputs(self.lowered_module, input)
        post_bundled = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        # Save and load the lowered module.
        self.save_load()
        # Use bundled after save and load to prove its preserved
        post_load = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        self.assertEqual(pre_bundled, post_bundled)
        self.assertEqual(post_bundled, post_load)