xref: /aosp_15_r20/external/pytorch/test/package/test_model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: package/deploy"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom io import BytesIO
4*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
5*da0073e9SAndroid Build Coastguard Workerfrom unittest import skipIf
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch.package import PackageExporter, PackageImporter, sys_importer
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# torchvision is an optional dependency: record whether it is importable so
# the tests below can be skipped cleanly when it is missing.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator that skips a test (or a whole test class) without torchvision.
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Workertry:
21*da0073e9SAndroid Build Coastguard Worker    from .common import PackageTestCase
22*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
23*da0073e9SAndroid Build Coastguard Worker    # Support the case where we run this file directly.
24*da0073e9SAndroid Build Coastguard Worker    from common import PackageTestCase
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
@skipIf(
    True,
    "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115",
)
@skipIfNoTorchVision
class ModelTest(PackageTestCase):
    """End-to-end tests packaging an entire model.

    Each test packages a torchvision ``resnet18`` with ``torch.package`` and
    checks that the packaged copy produces the same output as the original.
    The class is currently skipped unconditionally (see the issue linked in
    the ``skipIf`` above) and also requires torchvision to be installed.
    """

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_resnet(self):
        """Pickle a model into a package, load it, then re-export the loaded copy."""
        resnet = resnet18()

        f1 = self.temp()

        # Create a package that saves the model along with its code.
        with PackageExporter(f1) as e:
            # Put the pickled resnet in the package; with intern("**") the
            # exporter also saves the source of every module referenced by
            # the objects in the pickle.
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        # We can now load the saved model.
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # Test that it works.
        example_input = torch.rand(1, 3, 224, 224)
        ref = resnet(example_input)
        self.assertEqual(r2(example_input), ref)

        # Functions also exist to get at the private modules in each package:
        # the loaded model's class comes from the package's own torchvision.
        packaged_torchvision = i.import_module("torchvision")
        self.assertIs(type(r2), packaged_torchvision.models.resnet.ResNet)

        f2 = BytesIO()
        # If we are doing transfer learning we might want to re-save things
        # that were loaded from a package. We need to tell the exporter about
        # any modules that came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet to their source code.
        with PackageExporter(f2, importer=(i, sys_importer)) as e:
            # The `importer` argument is a sequence of importers (by default
            # just `sys_importer`). It is searched in order until the first
            # success, and that module is taken to be what
            # torchvision.models.resnet should be in this package. In the case
            # of name collisions, such as trying to save a ResNet from two
            # different packages, we take the first thing found, so only
            # ResNet objects from one importer will work. This avoids a bunch
            # of name mangling in the source code. If you need to actually mix
            # ResNet objects, we suggest reconstructing the model objects using
            # code from a single package, using state_dict() and
            # load_state_dict() to transfer state to the correct code objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertEqual(r3(example_input), ref)

    def test_model_save(self):
        """Package a model two different ways and load both through one API."""
        # This example shows how you might package a model so that the
        # creator of the model has flexibility about how they want to save
        # it, but the 'server' can always use the same API to load the
        # package.

        # The convention is for each model to provide a 'model' package with
        # a 'load' function that actually reads the model out of the archive.

        # How the load function is implemented is up to the packager.

        # Get our normal torchvision resnet.
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)
            # Note that this source is the same for all models in this
            # approach, so it can be made part of an API that just takes the
            # model and packages it with this source.
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convension
                def load():
                    return resources.load_pickle('model', 'pickled')
                """
            )
            e.save_source_string("model", src, is_package=True)

        f2 = BytesIO()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2) as e:
            e.intern("**")
            e.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can later edit how resnet is constructed here
                    # to edit the model in the package, while still loading the original
                    # state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """
            )
            e.save_source_string("model", src, is_package=True)

        # Regardless of how we chose to package, we can now use the model in
        # a server in the same way.
        example_input = torch.rand(1, 3, 224, 224)
        results = []
        for buf in (f1, f2):
            buf.seek(0)
            importer = PackageImporter(buf)
            the_model = importer.import_module("model").load()
            results.append(the_model(example_input))

        self.assertEqual(*results)

    def test_script_resnet(self):
        """A model loaded from a package can be scripted, saved, and reloaded."""
        resnet = resnet18()

        f1 = BytesIO()
        # Save the whole model by pickling it into a package.
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)

        f1.seek(0)

        i = PackageImporter(f1)
        packaged = i.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(packaged)

        # Scripted model should save and load successfully.
        f2 = BytesIO()
        torch.jit.save(scripted, f2)
        f2.seek(0)
        reloaded = torch.jit.load(f2)

        example_input = torch.rand(1, 3, 224, 224)
        self.assertEqual(reloaded(example_input), resnet(example_input))
199*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed directly (not collected by a test runner): hand control to
    # PyTorch's shared test entry point.
    run_tests()
202