# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
# Decorator that skips tests needing torchvision when it is not installed.
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""

    def test_package_interface(self):
        """Packaging an interface class should work correctly."""

        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", uses_interface)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        input = torch.tensor(1)

        self.assertEqual(scripted(input), scripted_loaded(input))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self) -> None:
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        """A module saved with ``save_module`` and re-imported from the
        package should produce the same ``uses_script_class`` result as the
        module imported from the loading environment.
        """
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        input = torch.tensor(1)
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(input), loaded.uses_script_class(input)
            )
        )

    def test_package_script_class_referencing_self(self):
        """Script an object loaded from a package after the same type has
        already been scripted outside the package, and check that the
        resulting ScriptModule can be saved and reloaded with ``torch.jit``.
        """
        import package_a.fake_script_class as fake

        obj = fake.UsesIdListFeature()
        # intentionally script here to fill the compilation cache, to make sure
        # there is no false sharing between scripted types coming from the
        # package vs. outside environment.
        torch.jit.script(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        obj_loaded = importer.load_pickle("obj", "obj.pkl")
        scripted_obj_loaded = torch.jit.script(obj_loaded)

        # Make sure the scripted object can be serialized without error.
        buffer2 = scripted_obj_loaded.save_to_buffer()
        torch.jit.load(BytesIO(buffer2))

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`.
        buffer = BytesIO()
        with PackageExporter(buffer) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        input = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(input)
        orig_script_class = fake.MyScriptClass(input)
        self.assertEqual(loaded_script_class.bar, orig_script_class.foo)

    def test_save_scriptmodule(self):
        """
        Test basic saving of ScriptModule.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_save_scriptmodule_file(self):
        """
        Test basic saving of ScriptModule in file.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        importer = PackageImporter(filename)
        loaded_mod = importer.load_pickle("res", "mod.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodule_with_submods(self):
        """
        Test basic saving of ScriptModule with submodule.
        """
        from package_a.test_module import ModWithSubmod, ModWithTensor

        scripted_mod = torch.jit.script(
            ModWithSubmod(ModWithTensor(torch.rand(1, 2, 3)))
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodules_submod_redefinition(self):
        """
        Test to verify saving multiple ScriptModules with same top module
        but different submodules works. The submodule is redefined between
        the two scriptings of the top module to check that the different
        concrete types of the modules are correctly recognized by the
        serialization code.
        """

        class Submod(torch.nn.Module):
            def forward(self, input: str):
                input = input + "_submod"
                return input

        class TopMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # `Submod` is looked up at construction time, so the
                # redefinition below changes what the second TopMod() holds.
                self.modB = Submod()

            def forward(self, input: str):
                return self.modB(input)

        scripted_mod_0 = torch.jit.script(TopMod())

        # Redefinition is intentional: changing only the string appended in
        # forward() should trigger a new concrete module type.
        class Submod(torch.nn.Module):  # noqa: F811
            def forward(self, input: str):
                input = input + "_submod(changed)"
                return input

        scripted_mod_1 = torch.jit.script(TopMod())

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        self.assertEqual(loaded_mod_0("input"), scripted_mod_0("input"))
        self.assertEqual(loaded_mod_1("input"), scripted_mod_1("input"))
        self.assertNotEqual(loaded_mod_0("input"), loaded_mod_1("input"))

    def test_save_independent_scriptmodules(self):
        """
        Test to verify saving multiple ScriptModules with completely
        separate code works.
        """
        from package_a.test_module import ModWithTensor, SimpleTest

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))

    def test_save_repeat_scriptmodules(self):
        """
        Test to verify saving multiple different modules and
        repeats of same scriptmodule in package works. Also tests that
        PyTorchStreamWriter writing the same ScriptModule code files
        multiple times does not hide code from PyTorchStreamReader.
        """
        from package_a.test_module import (
            ModWithSubmodAndTensor,
            ModWithTensor,
            SimpleTest,
        )

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_2 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod0.pkl", scripted_mod_0)
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", scripted_mod_0)
            e.save_pickle("res", "mod3.pkl", scripted_mod_1)
            e.save_pickle("res", "mod4.pkl", scripted_mod_2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod0.pkl")
        # mod3.pkl is the repeated save of scripted_mod_1.
        loaded_mod_1 = importer.load_pickle("res", "mod3.pkl")
        loaded_mod_2 = importer.load_pickle("res", "mod4.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))
        self.assertEqual(loaded_mod_2(input), scripted_mod_2(input))

    def test_scriptmodules_repeat_save(self):
        """
        Test to verify saving and loading same ScriptModule object works
        across multiple packages.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_0 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_1 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_module_0 = importer_0.load_pickle("res", "mod1.pkl")

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", loaded_module_0)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_module_1 = importer_1.load_pickle("res", "mod1.pkl")
        reloaded_module_0 = importer_1.load_pickle("res", "mod2.pkl")

        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_module_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_module_0(input), reloaded_module_0(input))
        self.assertEqual(loaded_module_1(input), scripted_mod_1(input))

    @skipIfNoTorchVision
    def test_save_scriptmodule_only_necessary_code(self):
        """
        Test to verify when saving multiple packages with same CU
        that packages don't include unnecessary torchscript code files.
        The TorchVision code should only be saved in the package that
        relies on it.
        """
        from package_a.test_module import ModWithTensor

        class ModWithTorchVision(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, input):
                return input * 4

        scripted_mod_0 = torch.jit.script(ModWithTorchVision("foo"))
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)

        self.assertIn("torchvision", str(importer_0.file_structure()))
        self.assertNotIn("torchvision", str(importer_1.file_structure()))

    def test_save_scriptmodules_in_container(self):
        """
        Test saving of ScriptModules inside of container. Checks that relations
        between shared modules are upheld.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_a = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_b = torch.jit.script(
            ModWithSubmodAndTensor(torch.rand(1, 2, 3), scripted_mod_a)
        )
        script_mods_list = [scripted_mod_a, scripted_mod_b]

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "list.pkl", script_mods_list)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_list = importer.load_pickle("res", "list.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_list[0](input), scripted_mod_a(input))
        self.assertEqual(loaded_mod_list[1](input), scripted_mod_b(input))

    def test_save_eager_mods_sharing_scriptmodule(self):
        """
        Test saving of single ScriptModule shared by multiple
        eager modules (ScriptModule should be saved just once
        even though it is contained in multiple pickles).
        """
        from package_a.test_module import ModWithSubmod, SimpleTest

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file(".data/ts_code/0"))
        self.assertFalse(file_structure.has_file(".data/ts_code/1"))

    def test_load_shared_scriptmodules(self):
        """
        Test loading of single ScriptModule shared by multiple eager
        modules in single pickle (ScriptModule objects should be the same).
        """
        from package_a.test_module import (
            ModWithMultipleSubmods,
            ModWithSubmod,
            SimpleTest,
        )

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        mod_parent = ModWithMultipleSubmods(mod1, mod2)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod.pkl", mod_parent)

        buffer.seek(0)
        importer = PackageImporter(buffer)

        loaded_mod = importer.load_pickle("res", "mod.pkl")
        self.assertIs(loaded_mod.mod1.script_mod, loaded_mod.mod2.script_mod)

    def test_save_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules are serialized once.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        shared_tensor = torch.rand(2, 3, 4)
        scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
        mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "tensor", shared_tensor)
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        # assert that there is only one storage stored in package
        file_structure = importer.file_structure(include=".data/*.storage")
        self.assertEqual(len(file_structure.children[".data"].children), 1)

        input = torch.rand(2, 3, 4)
        self.assertEqual(loaded_mod_1(input), mod1(input))

    def test_load_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules on load
        are the same.
        """
        from package_a.test_module import ModWithTensor, ModWithTwoSubmodsAndTensor

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        # Sanity-check the precondition: all three tensors share one storage
        # before saving.
        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_1.tensor.storage()._cdata,
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        # An in-place update through one reference must be visible through
        # the others if the storage is truly shared.
        loaded_mod_1.tensor.add_(torch.ones(3, 3))

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_load_shared_tensors_repackaged(self):
        """
        Test tensors shared across eager and ScriptModules on load
        are the same across multiple package saves and loads. This is
        an important test because not all of the tensor information is restored
        in python between packages. The python identity is not maintained, but
        the backing cpp TensorImpl is. We load/save storages based off of this
        cpp TensorImpl and not the python identity.
        """
        from package_a.test_module import ModWithTensor, ModWithTwoSubmodsAndTensor

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod_0 = importer_0.load_pickle("res", "mod1.pkl")

        # Re-package the loaded module, resolving its code through importer_0.
        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", loaded_mod_0)

        buffer_1.seek(0)
        importer = PackageImporter(buffer_1)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        # all tensors should reflect this change
        loaded_mod_1.tensor.add_(torch.ones(3, 3))

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_saving_and_scripting_packaged_mod(self):
        """
        Test scripting a module loaded from a package
        and saving it in a new package as a script object.
        """
        from package_a.test_module import SimpleTest

        orig_mod = SimpleTest()

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", orig_mod)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod = importer_0.load_pickle("model", "model.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_mod(input), orig_mod(input))

        scripted_mod = torch.jit.script(loaded_mod)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "scripted_mod.pkl", scripted_mod)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_mod_scripted = importer_1.load_pickle("res", "scripted_mod.pkl")

        self.assertEqual(loaded_mod_scripted(input), orig_mod(input))

    def test_mixing_packaged_and_inline_modules(self):
        """
        Test saving inline and imported modules in same package with
        independent code.
        """

        class InlineMod(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.name = name
                self.tensor = torch.rand(1, 2, 3)

            def forward(self, input: str):
                input = input + "_modInline:" + self.name
                return input, (self.tensor * 4)

        inline_mod = InlineMod("inline")
        scripted_inline = torch.jit.script(inline_mod)

        from package_a.test_module import SimpleTest

        imported_mod = SimpleTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline("input"), inline_mod("input"))

    @skipIfNoTorchVision
    def test_mixing_packaged_and_inline_modules_shared_code(self):
        """
        Test saving inline and imported modules in same package that
        share code.
        """

        class TorchVisionTestInline(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, x):
                x = a_non_torch_leaf(x, x)
                return torch.relu(x + 3.0)

        # Defined after the class on purpose: forward() resolves it from the
        # enclosing scope when the module is scripted below.
        def a_non_torch_leaf(a, b):
            return a + b

        inline_mod = TorchVisionTestInline()
        scripted_inline = torch.jit.script(inline_mod)

        from package_c.test_module import TorchVisionTest

        imported_mod = TorchVisionTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline(input), inline_mod(input))

    def test_tensor_sharing_pickle(self):
        """Test that saving a ScriptModule and separately saving a tensor
        object causes no issues.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.ones(2, 3)

            def forward(self):
                return self.foo

        scripted_m = torch.jit.script(M())
        original_tensor = torch.ones(0)

        f = BytesIO()
        with PackageExporter(f) as exporter:
            exporter.save_pickle("model", "model.pkl", scripted_m)
            exporter.save_pickle("model", "input.pkl", original_tensor)

        f.seek(0)
        # Should be able to load correctly
        importer = PackageImporter(f)
        loaded_m = importer.load_pickle("model", "model.pkl")
        loaded_tensor = importer.load_pickle("model", "input.pkl")

        self.assertEqual(scripted_m.foo, loaded_m.foo)
        self.assertEqual(original_tensor, loaded_tensor)


if __name__ == "__main__":
    run_tests()