# Owner(s): ["oncall: package/deploy"]
2
3from io import BytesIO
4from textwrap import dedent
5from unittest import skipIf
6
7import torch
8from torch.package import PackageExporter, PackageImporter, sys_importer
9from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests
10
11
12try:
13    from torchvision.models import resnet18
14
15    HAS_TORCHVISION = True
16except ImportError:
17    HAS_TORCHVISION = False
18skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")
19
20try:
21    from .common import PackageTestCase
22except ImportError:
23    # Support the case where we run this file directly.
24    from common import PackageTestCase
25
26
@skipIf(
    True,
    "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115",
)
@skipIfNoTorchVision
class ModelTest(PackageTestCase):
    """End-to-end tests packaging an entire model."""

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_resnet(self):
        """Round-trip a resnet18 through a package file.

        Packages the model with its code, reloads it, checks numerical
        equality with the original, then re-exports the loaded copy to a
        second package (the transfer-learning workflow) and checks again.
        """
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1) as e:
            # put the pickled resnet in the package, by default
            # this will also save all the code files referenced by
            # the objects in the pickle
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # test that it works (avoid shadowing the `input` builtin)
        inp = torch.rand(1, 3, 224, 224)
        ref = resnet(inp)
        self.assertEqual(r2(inp), ref)

        # functions exist also to get at the private modules in each package
        # (kept even though unused: demonstrates the import_module API)
        torchvision = i.import_module("torchvision")

        f2 = BytesIO()
        # if we are doing transfer learning we might want to re-save
        # things that were loaded from a package.
        # We need to tell the exporter about any modules that
        # came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet
        # to their source code.
        with PackageExporter(f2, importer=(i, sys_importer)) as e:
            # e.importers is a list of module importing functions
            # that by default contains importlib.import_module.
            # it is searched in order until the first success and
            # that module is taken to be what torchvision.models.resnet
            # should be in this code package. In the case of name collisions,
            # such as trying to save a ResNet from two different packages,
            # we take the first thing found in the path, so only ResNet objects from
            # one importer will work. This avoids a bunch of name mangling in
            # the source code. If you need to actually mix ResNet objects,
            # we suggest reconstructing the model objects using code from a single package
            # using functions like save_state_dict and load_state_dict to transfer state
            # to the correct code objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertEqual(r3(inp), ref)

    @skipIfNoTorchVision
    def test_model_save(self):
        """Package a model two ways and load both through one `load()` API.

        Option 1 pickles the whole model; option 2 saves only the state
        dict and reconstructs the model in packaged source. Both packages
        expose a `model.load()` entry point so the consumer is agnostic to
        how the model was saved.
        """
        # This example shows how you might package a model
        # so that the creator of the model has flexibility about
        # how they want to save it but the 'server' can always
        # use the same API to load the package.

        # The convention is for each model to provide a
        # 'model' package with a 'load' function that actually
        # reads the model out of the archive.

        # How the load function is implemented is up to
        # the packager.

        # get our normal torchvision resnet
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)
            # note that this source is the same for all models in this approach
            # so it can be made part of an API that just takes the model and
            # packages it with this source.
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convension
                def load():
                    return resources.load_pickle('model', 'pickled')
                """
            )
            e.save_source_string("model", src, is_package=True)

        f2 = BytesIO()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2) as e:
            e.intern("**")
            e.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can later edit how resnet is constructed here
                    # to edit the model in the package, while still loading the original
                    # state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """
            )
            e.save_source_string("model", src, is_package=True)

        # regardless of how we chose to package, we can now use the model in a server in the same way
        inp = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            m.seek(0)
            importer = PackageImporter(m)
            the_model = importer.import_module("model").load()
            results.append(the_model(inp))

        # both packaging options must produce identical outputs
        self.assertEqual(*results)

    @skipIfNoTorchVision
    def test_script_resnet(self):
        """A model loaded from a package must still TorchScript-compile.

        Pickles a resnet18 into a package, reloads it, scripts it with
        torch.jit, round-trips the scripted module through jit.save/load,
        and checks the output matches the eager original.
        """
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)

        f1.seek(0)

        i = PackageImporter(f1)
        loaded = i.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(loaded)

        # Scripted model should save and load successfully.
        f2 = BytesIO()
        torch.jit.save(scripted, f2)
        f2.seek(0)
        loaded = torch.jit.load(f2)

        inp = torch.rand(1, 3, 224, 224)
        self.assertEqual(loaded(inp), resnet(inp))
198
199
# Support running this file directly (e.g. `python test_model.py`) as well
# as through the normal test runner.
if __name__ == "__main__":
    run_tests()
202