# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")

try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


@skipIf(
    True,
    "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115",
)
@skipIfNoTorchVision
class ModelTest(PackageTestCase):
    """End-to-end tests packaging an entire model.

    ``skipIfNoTorchVision`` is applied to the class, so every test in it is
    skipped when torchvision is unavailable; individual tests do not need to
    repeat that decorator.
    """

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_resnet(self):
        """Package a ResNet to a file, load it back, and re-package the loaded copy."""
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1) as e:
            # put the pickled resnet in the package; with "**" interned this
            # also saves all the code files referenced by the objects in the
            # pickle
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # test that it works
        inp = torch.rand(1, 3, 224, 224)
        ref = resnet(inp)
        self.assertEqual(r2(inp), ref)

        # functions also exist to get at the private modules in each package
        i.import_module("torchvision")

        f2 = BytesIO()
        # If we are doing transfer learning we might want to re-save things
        # that were loaded from a package. We need to tell the exporter about
        # any modules that came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet to their source
        # code.
        with PackageExporter(f2, importer=(i, sys_importer)) as e:
            # `importer` is an ordered sequence of importers. It is searched
            # in order until the first success, and that module is taken to be
            # what torchvision.models.resnet should be in this code package.
            # In the case of name collisions, such as trying to save a ResNet
            # from two different packages, we take the first thing found, so
            # only ResNet objects from one importer will work. This avoids a
            # bunch of name mangling in the source code. If you need to
            # actually mix ResNet objects, we suggest reconstructing the model
            # objects using code from a single package, using functions like
            # save_state_dict and load_state_dict to transfer state to the
            # correct code objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertEqual(r3(inp), ref)

    def test_model_save(self):
        """Two packaging strategies load identically through a common ``model.load()`` entry point.

        This example shows how you might package a model so that the creator
        of the model has flexibility about how they want to save it, but the
        'server' can always use the same API to load the package.

        The convention is for each model to provide a 'model' package with a
        'load' function that actually reads the model out of the archive. How
        the load function is implemented is up to the packager.
        """
        # get our normal torchvision resnet
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)
            # note that this source is the same for all models in this approach
            # so it can be made part of an API that just takes the model and
            # packages it with this source.
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convension
                def load():
                    return resources.load_pickle('model', 'pickled')
                """
            )
            e.save_source_string("model", src, is_package=True)

        f2 = BytesIO()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2) as e:
            e.intern("**")
            e.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can later edit how resnet is constructed here
                    # to edit the model in the package, while still loading the original
                    # state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """
            )
            e.save_source_string("model", src, is_package=True)

        # regardless of how we chose to package, we can now use the model in a
        # server in the same way
        inp = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            m.seek(0)
            importer = PackageImporter(m)
            the_model = importer.import_module("model").load()
            results.append(the_model(inp))

        self.assertEqual(*results)

    def test_script_resnet(self):
        """A model loaded from a package can be TorchScripted, saved, and reloaded."""
        resnet = resnet18()

        f1 = BytesIO()
        # save the whole model into the package by pickling it
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)

        f1.seek(0)

        i = PackageImporter(f1)
        unpickled = i.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(unpickled)

        # Scripted model should save and load successfully.
        f2 = BytesIO()
        torch.jit.save(scripted, f2)
        f2.seek(0)
        reloaded = torch.jit.load(f2)

        inp = torch.rand(1, 3, 224, 224)
        self.assertEqual(reloaded(inp), resnet(inp))


if __name__ == "__main__":
    run_tests()