# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")

try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


@skipIf(
    True,
    "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115",
)
@skipIfNoTorchVision
class ModelTest(PackageTestCase):
    """End-to-end tests packaging an entire model.

    The class-level ``skipIfNoTorchVision`` decorator applies to every test
    method below, so the individual tests do not repeat it.
    """

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_resnet(self):
        """Package a ResNet, load it back, then re-package the loaded copy."""
        resnet = resnet18()

        f1 = self.temp()

        # Create a package that saves the model along with its code.
        with PackageExporter(f1) as e:
            # Put the pickled resnet in the package. Interning "**" also saves
            # the source of every module referenced by the pickled objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        # We can now load the saved model.
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # Check that the loaded model computes the same result as the original.
        example_input = torch.rand(1, 3, 224, 224)
        ref = resnet(example_input)
        self.assertEqual(r2(example_input), ref)

        # Functions also exist to get at the private modules in each package.
        packaged_torchvision = i.import_module("torchvision")
        self.assertIsNotNone(packaged_torchvision)

        f2 = BytesIO()
        # If we are doing transfer learning, we might want to re-save things
        # that were loaded from a package. We need to tell the exporter about
        # any modules that came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet to their source.
        with PackageExporter(f2, importer=(i, sys_importer)) as e:
            # The importers are searched in order until the first success, and
            # that module is taken to be what torchvision.models.resnet should
            # be in this package. On name collisions (such as trying to save a
            # ResNet from two different packages), the first match wins, so
            # only ResNet objects from one importer will work. This avoids
            # name mangling in the source code. If you need to mix ResNet
            # objects, reconstruct the model objects using code from a single
            # package and transfer state with state_dict / load_state_dict.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertEqual(r3(example_input), ref)

    def test_model_save(self):
        """Show two packaging styles that a server can load through one API.

        The convention is for each package to provide a ``model`` module with
        a ``load`` function that reads the model out of the archive. How
        ``load`` is implemented is up to the packager.
        """
        # Get our normal torchvision resnet.
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model.
        # + single line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)
            # This source is the same for every model saved this way, so it
            # could be part of an API that takes just the model and packages
            # it together with this source.
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convension
                def load():
                    return resources.load_pickle('model', 'pickled')
                """
            )
            e.save_source_string("model", src, is_package=True)

        f2 = BytesIO()
        # Option 2: save the state dict.
        # - more code to write to save/load the model
        # + this code can be edited later to adapt the model
        with PackageExporter(f2) as e:
            e.intern("**")
            e.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can later edit how resnet is constructed here
                    # to edit the model in the package, while still loading the original
                    # state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """
            )
            e.save_source_string("model", src, is_package=True)

        # Regardless of how we packaged, a server can use the model the same way.
        example_input = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            m.seek(0)
            importer = PackageImporter(m)
            the_model = importer.import_module("model").load()
            results.append(the_model(example_input))

        self.assertEqual(*results)

    def test_script_resnet(self):
        """A model loaded from a package can be scripted, saved and reloaded."""
        resnet = resnet18()

        f1 = BytesIO()
        # Save by pickling the whole model.
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)

        f1.seek(0)

        i = PackageImporter(f1)
        unpickled = i.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(unpickled)

        # Scripted model should save and load successfully.
        f2 = BytesIO()
        torch.jit.save(scripted, f2)
        f2.seek(0)
        reloaded = torch.jit.load(f2)

        example_input = torch.rand(1, 3, 224, 224)
        self.assertEqual(reloaded(example_input), resnet(example_input))


if __name__ == "__main__":
    run_tests()