# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")

try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


@skipIf(
    True,
    "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115",
)
@skipIfNoTorchVision
class ModelTest(PackageTestCase):
    """End-to-end tests packaging an entire model.

    Each test builds a torchvision ``resnet18`` and packages it with
    ``PackageExporter``. It then loads the model back (through
    ``PackageImporter`` or, after scripting, ``torch.jit.load``). Finally it
    checks the reloaded model against the original on a random
    ``(1, 3, 224, 224)`` tensor.
    """

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_resnet(self):
        """Package a model, load it back, then re-package the loaded copy."""
        resnet = resnet18()

        f1 = self.temp()

        # Create a package that saves the model along with its code.
        with PackageExporter(f1) as e:
            # Put the pickled resnet in the package. Interning "**" also
            # saves the code files referenced by the objects in the pickle.
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        # We can now load the saved model.
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # Test that it works. The name avoids shadowing the `input` builtin.
        example_input = torch.rand(1, 3, 224, 224)
        ref = resnet(example_input)
        self.assertEqual(r2(example_input), ref)

        # Functions also exist to get at the private modules in each package.
        # The returned module is not needed; the call just exercises the API.
        i.import_module("torchvision")

        f2 = BytesIO()
        # If we are doing transfer learning we might want to re-save
        # things that were loaded from a package.
        # We need to tell the exporter about any modules that
        # came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet
        # to their source code.
        with PackageExporter(f2, importer=(i, sys_importer)) as e:
            # `importer` is an ordered sequence of importers. It is searched
            # in order until the first success, and that module is taken to
            # be what torchvision.models.resnet should be in this code
            # package. Here the package importer `i` is consulted before
            # `sys_importer`, the regular Python import system.
            #
            # In the case of name collisions, such as saving a ResNet from
            # two different packages, the first match wins. Only ResNet
            # objects from one importer will work, which avoids a lot of
            # name mangling in the source code.
            #
            # If you really need to mix ResNet objects, rebuild the model
            # objects using code from a single package. Transfer the state
            # with `state_dict` and `load_state_dict`.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertEqual(r3(example_input), ref)

    @skipIfNoTorchVision
    def test_model_save(self):
        """Check that two packaging strategies load through the same API.

        This example shows how you might package a model so that its creator
        can choose how to save it. The 'server' can always use the same API
        to load the package. The convention is for each model to provide a
        'model' package with a 'load' function that actually reads the model
        out of the archive. How that load function is implemented is up to
        the packager.
        """
        # Get our normal torchvision resnet.
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)
            # Note that this source is the same for all models in this
            # approach. It can therefore be part of an API that just takes
            # the model and packages it with this source.
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convension
                def load():
                    return resources.load_pickle('model', 'pickled')
                """
            )
            e.save_source_string("model", src, is_package=True)

        f2 = BytesIO()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2) as e:
            e.intern("**")
            e.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can later edit how resnet is constructed here
                    # to edit the model in the package, while still loading the original
                    # state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """
            )
            e.save_source_string("model", src, is_package=True)

        # However we packaged it, a server uses the model the same way.
        example_input = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            m.seek(0)
            importer = PackageImporter(m)
            the_model = importer.import_module("model").load()
            results.append(the_model(example_input))

        self.assertEqual(*results)

    @skipIfNoTorchVision
    def test_script_resnet(self):
        """Check that a packaged model can be scripted, saved and reloaded."""
        resnet = resnet18()

        f1 = BytesIO()
        # Package the whole model by pickling it.
        with PackageExporter(f1) as e:
            e.intern("**")
            e.save_pickle("model", "pickled", resnet)

        f1.seek(0)

        i = PackageImporter(f1)
        loaded = i.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(loaded)

        # Scripted model should save and load successfully.
        f2 = BytesIO()
        torch.jit.save(scripted, f2)
        f2.seek(0)
        loaded = torch.jit.load(f2)

        example_input = torch.rand(1, 3, 224, 224)
        self.assertEqual(loaded(example_input), resnet(example_input))


if __name__ == "__main__":
    run_tests()