# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""

    def test_package_interface(self):
        """Packaging an interface class should work correctly."""

        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
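        # proxy_mod is declared with the interface type, so a scripted
        # NewModule (another implementer of the interface) can be swapped in.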
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
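            # intern("**") packages the source of every module the pickled
            # object depends on, instead of relying on the loading environment.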
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", uses_interface)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        input = torch.tensor(1)

        self.assertEqual(scripted(input), scripted_loaded(input))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
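            # save_source_string writes this source into the package under
            # fake.__name__, shadowing the interface definition already loaded
            # in this process.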
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self) -> None:
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
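            # save_module packages the source of fake_script_class as it exists
            # in the current environment.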
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        input = torch.tensor(1)
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(input), loaded.uses_script_class(input)
            )
        )

    def test_package_script_class_referencing_self(self):
        import package_a.fake_script_class as fake

        obj = fake.UsesIdListFeature()
        # intentionally script here to fill the compilation cache, to make sure
        # there is no false sharing between scripted types coming from the
        # package vs. outside environment.
        torch.jit.script(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        obj_loaded = importer.load_pickle("obj", "obj.pkl")
        scripted_obj_loaded = torch.jit.script(obj_loaded)

        # Make sure the scripted object can be serialized without error.
        buffer2 = scripted_obj_loaded.save_to_buffer()
        torch.jit.load(BytesIO(buffer2))

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`.
        buffer = BytesIO()
        with PackageExporter(buffer) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        input = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(input)
        orig_script_class = fake.MyScriptClass(input)
        self.assertEqual(loaded_script_class.bar, orig_script_class.foo)

    def test_save_scriptmodule(self):
        """
        Test basic saving of ScriptModule.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
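        # map_location works like torch.load's: it remaps any packaged tensors,
        # here keeping them on the CPU.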
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_save_scriptmodule_file(self):
        """
        Test basic saving of ScriptModule in file.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        importer = PackageImporter(filename)
        loaded_mod = importer.load_pickle("res", "mod.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodule_with_submods(self):
        """
        Test basic saving of ScriptModule with submodule.
        """
        from package_a.test_module import ModWithSubmod, ModWithTensor

        scripted_mod = torch.jit.script(
            ModWithSubmod(ModWithTensor(torch.rand(1, 2, 3)))
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodules_submod_redefinition(self):
        """
        Test to verify that saving multiple ScriptModules with the same top
        module but different submodules works. The submodule is redefined
        between the two scriptings of the top module to check that the
        different concrete types of the modules are correctly recognized by
        the serialization code.
        """

        class Submod(torch.nn.Module):
            def forward(self, input: str):
                input = input + "_submod"
                return input

        class TopMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.modB = Submod()

            def forward(self, input: str):
                return self.modB(input)

        scripted_mod_0 = torch.jit.script(TopMod())

        # redefinition is intentional: changing a single inner string
        # should trigger a new concrete module type
        class Submod(torch.nn.Module):  # noqa: F811
            def forward(self, input: str):
                input = input + "_submod(changed)"
                return input

        scripted_mod_1 = torch.jit.script(TopMod())

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        self.assertEqual(loaded_mod_0("input"), scripted_mod_0("input"))
        self.assertEqual(loaded_mod_1("input"), scripted_mod_1("input"))
        self.assertNotEqual(loaded_mod_0("input"), loaded_mod_1("input"))

    def test_save_independent_scriptmodules(self):
        """
        Test to verify saving multiple ScriptModules with completely
        separate code works.
        """
        from package_a.test_module import ModWithTensor, SimpleTest

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))

    def test_save_repeat_scriptmodules(self):
        """
        Test to verify that saving multiple different modules, and repeats of
        the same ScriptModule, in one package works. Also tests that
        PyTorchStreamWriter writing ScriptModule code files multiple times
        doesn't hide code from PyTorchStreamReader.
        """
        from package_a.test_module import (
            ModWithSubmodAndTensor,
            ModWithTensor,
            SimpleTest,
        )

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_2 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
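            # scripted_mod_0 and scripted_mod_1 are each pickled twice; their
            # TorchScript code should still be written to the archive only once.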
            e.save_pickle("res", "mod0.pkl", scripted_mod_0)
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", scripted_mod_0)
            e.save_pickle("res", "mod3.pkl", scripted_mod_1)
            e.save_pickle("res", "mod4.pkl", scripted_mod_2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod0.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod3.pkl")
        loaded_mod_2 = importer.load_pickle("res", "mod4.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))
        self.assertEqual(loaded_mod_2(input), scripted_mod_2(input))

    def test_scriptmodules_repeat_save(self):
        """
        Test to verify saving and loading same ScriptModule object works
        across multiple packages.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_0 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_1 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_module_0 = importer_0.load_pickle("res", "mod1.pkl")

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", loaded_module_0)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_module_1 = importer_1.load_pickle("res", "mod1.pkl")
        reloaded_module_0 = importer_1.load_pickle("res", "mod2.pkl")

        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_module_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_module_0(input), reloaded_module_0(input))
        self.assertEqual(loaded_module_1(input), scripted_mod_1(input))

    @skipIfNoTorchVision
    def test_save_scriptmodule_only_necessary_code(self):
        """
        Test to verify that, when saving multiple packages from the same
        compilation unit (CU), packages don't include unnecessary TorchScript
        code files. The TorchVision code should only be saved in the package
        that relies on it.
        """
        from package_a.test_module import ModWithTensor

        class ModWithTorchVision(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, input):
                return input * 4

        scripted_mod_0 = torch.jit.script(ModWithTorchVision("foo"))
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)

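        # file_structure() returns a Directory view of the package contents;
        # only the package holding the torchvision-based module should contain
        # torchvision source files.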
        self.assertTrue("torchvision" in str(importer_0.file_structure()))
        self.assertFalse("torchvision" in str(importer_1.file_structure()))

    def test_save_scriptmodules_in_container(self):
        """
        Test saving of ScriptModules inside a container. Checks that relations
        between shared modules are upheld.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_a = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_b = torch.jit.script(
            ModWithSubmodAndTensor(torch.rand(1, 2, 3), scripted_mod_a)
        )
        script_mods_list = [scripted_mod_a, scripted_mod_b]

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "list.pkl", script_mods_list)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_list = importer.load_pickle("res", "list.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_list[0](input), scripted_mod_a(input))
        self.assertEqual(loaded_mod_list[1](input), scripted_mod_b(input))

    def test_save_eager_mods_sharing_scriptmodule(self):
        """
        Test saving of a single ScriptModule shared by multiple
        eager modules (the ScriptModule should be saved just once
        even though it is contained in multiple pickles).
        """
        from package_a.test_module import ModWithSubmod, SimpleTest

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
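        # Packaged TorchScript code is stored under .data/ts_code/<n>; a single
        # entry indicates the shared ScriptModule was serialized only once.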
        self.assertTrue(file_structure.has_file(".data/ts_code/0"))
        self.assertFalse(file_structure.has_file(".data/ts_code/1"))

    def test_load_shared_scriptmodules(self):
        """
        Test loading of a single ScriptModule shared by multiple eager
        modules in a single pickle (the ScriptModule objects should be the same).
        """
        from package_a.test_module import (
            ModWithMultipleSubmods,
            ModWithSubmod,
            SimpleTest,
        )

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        mod_parent = ModWithMultipleSubmods(mod1, mod2)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod.pkl", mod_parent)

        buffer.seek(0)
        importer = PackageImporter(buffer)

        loaded_mod = importer.load_pickle("res", "mod.pkl")
        self.assertTrue(
            id(loaded_mod.mod1.script_mod) == id(loaded_mod.mod2.script_mod)
        )

    def test_save_shared_tensors(self):
        """
        Test that tensors shared across eager modules and ScriptModules are
        serialized only once.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        shared_tensor = torch.rand(2, 3, 4)
        scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
        mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "tensor", shared_tensor)
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        # assert that there is only one storage stored in package
        file_structure = importer.file_structure(include=".data/*.storage")
        self.assertTrue(len(file_structure.children[".data"].children) == 1)

        input = torch.rand(2, 3, 4)
        self.assertEqual(loaded_mod_1(input), mod1(input))

    def test_load_shared_tensors(self):
        """
        Test that tensors shared across eager modules and ScriptModules are
        the same after loading.
        """
        from package_a.test_module import ModWithTensor, ModWithTwoSubmodsAndTensor

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

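        # storage()._cdata is the address of the underlying storage, so equal
        # values mean the tensors literally share memory.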
        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_1.tensor.storage()._cdata,
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        loaded_mod_1.tensor.add_(torch.ones(3, 3))

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_load_shared_tensors_repackaged(self):
        """
        Test that tensors shared across eager modules and ScriptModules are
        the same across multiple package saves and loads. This is an important
        test because not all of the tensor information is restored in Python
        between packages. The Python identity is not maintained, but the
        backing C++ TensorImpl is. We load/save storages based on this C++
        TensorImpl and not the Python identity.
        """
        from package_a.test_module import ModWithTensor, ModWithTwoSubmodsAndTensor

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod_0 = importer_0.load_pickle("res", "mod1.pkl")

        buffer_1 = BytesIO()
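        # Passing importer=importer_0 lets the exporter find the code for
        # classes that live inside the first package when re-pickling
        # loaded_mod_0.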
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", loaded_mod_0)

        buffer_1.seek(0)
        importer = PackageImporter(buffer_1)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        loaded_mod_1.tensor.add_(
            torch.ones(3, 3)
        )  # all tensors should reflect this change

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_saving_and_scripting_packaged_mod(self):
        """
        Test scripting a module loaded from a package
        and saving it in a new package as a script object.
        """
        from package_a.test_module import SimpleTest

        orig_mod = SimpleTest()

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", orig_mod)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod = importer_0.load_pickle("model", "model.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_mod(input), orig_mod(input))

        scripted_mod = torch.jit.script(loaded_mod)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "scripted_mod.pkl", scripted_mod)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_mod_scripted = importer_1.load_pickle("res", "scripted_mod.pkl")

        self.assertEqual(loaded_mod_scripted(input), orig_mod(input))

    def test_mixing_packaged_and_inline_modules(self):
        """
        Test saving inline and imported modules in the same package with
        independent code.
        """

        class InlineMod(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.name = name
                self.tensor = torch.rand(1, 2, 3)

            def forward(self, input: str):
                input = input + "_modInline:" + self.name
                return input, (self.tensor * 4)

        inline_mod = InlineMod("inline")
        scripted_inline = torch.jit.script(inline_mod)

        from package_a.test_module import SimpleTest

        imported_mod = SimpleTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline("input"), inline_mod("input"))

    @skipIfNoTorchVision
    def test_mixing_packaged_and_inline_modules_shared_code(self):
        """
        Test saving inline and imported modules in the same package that
        share code.
        """

        class TorchVisionTestInline(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, x):
                x = a_non_torch_leaf(x, x)
                return torch.relu(x + 3.0)

        def a_non_torch_leaf(a, b):
            return a + b

        inline_mod = TorchVisionTestInline()
        scripted_inline = torch.jit.script(inline_mod)

        from package_c.test_module import TorchVisionTest

        imported_mod = TorchVisionTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline(input), inline_mod(input))

    def test_tensor_sharing_pickle(self):
        """Test that saving a ScriptModule and separately saving a tensor
        object causes no issues.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.ones(2, 3)

            def forward(self):
                return self.foo

        scripted_m = torch.jit.script(M())
        original_tensor = torch.ones(0)

        f = BytesIO()
        with torch.package.PackageExporter(f) as exporter:
            exporter.save_pickle("model", "model.pkl", scripted_m)
            exporter.save_pickle("model", "input.pkl", original_tensor)

        f.seek(0)
        # Should be able to load correctly
        importer = PackageImporter(f)
        loaded_m = importer.load_pickle("model", "model.pkl")
        loaded_tensor = importer.load_pickle("model", "input.pkl")

        self.assertEqual(scripted_m.foo, loaded_m.foo)
        self.assertEqual(original_tensor, loaded_tensor)


if __name__ == "__main__":
    run_tests()