# Owner(s): ["oncall: package/deploy"]

from io import BytesIO

import torch
from torch.package import (
    Importer,
    OrderedImporter,
    PackageExporter,
    PackageImporter,
    sys_importer,
)
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


def _importer_with_saved_module(module_name):
    """Save ``module_name`` into an in-memory package and return a reader for it."""
    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.save_module(module_name)
    buf.seek(0)
    return PackageImporter(buf)


class TestImporter(PackageTestCase):
    """Tests for Importer and derived classes."""

    def test_sys_importer(self):
        """sys_importer must hand back the module objects of the host interpreter."""
        import package_a
        import package_a.subpackage

        expectations = (
            ("package_a", package_a),
            ("package_a.subpackage", package_a.subpackage),
        )
        for dotted_name, expected_module in expectations:
            self.assertIs(sys_importer.import_module(dotted_name), expected_module)

    def test_sys_importer_roundtrip(self):
        """get_name() followed by import_module() must recover the same type."""
        import package_a
        import package_a.subpackage

        cls = package_a.subpackage.PackageASubpackageObject
        mod_name, attr_name = sys_importer.get_name(cls)
        owner = sys_importer.import_module(mod_name)
        self.assertIs(getattr(owner, attr_name), cls)

    def test_single_ordered_importer(self):
        """An OrderedImporter over only a PackageImporter sees only packaged code."""
        import module_a  # noqa: F401
        import package_a

        packaged = _importer_with_saved_module(package_a.__name__)
        # No sys_importer in the chain: nothing falls back to the host environment.
        isolated = OrderedImporter(packaged)

        # The resolved module is the packaged copy ...
        self.assertIs(
            isolated.import_module("package_a"),
            packaged.import_module("package_a"),
        )
        # ... and not the one loaded in the outer interpreter.
        self.assertIsNot(isolated.import_module("package_a"), package_a)

        # module_a was never saved into the package, so the lookup must fail.
        with self.assertRaises(ModuleNotFoundError):
            isolated.import_module("module_a")

    def test_ordered_importer_basic(self):
        """Whichever importer is listed first wins the lookup."""
        import package_a

        packaged = _importer_with_saved_module(package_a.__name__)

        sys_first = OrderedImporter(sys_importer, packaged)
        self.assertIs(sys_first.import_module("package_a"), package_a)

        packaged_first = OrderedImporter(packaged, sys_importer)
        self.assertIs(
            packaged_first.import_module("package_a"),
            packaged.import_module("package_a"),
        )

    def test_ordered_importer_whichmodule(self):
        """OrderedImporter's implementation of whichmodule should try each
        underlying importer's whichmodule in order.
        """

        class DummyImporter(Importer):
            def __init__(self, answer):
                self._answer = answer

            def import_module(self, module_name):
                raise NotImplementedError

            def whichmodule(self, obj, name):
                return self._answer

        class DummyClass:
            pass

        foo = DummyImporter("foo")
        bar = DummyImporter("bar")
        # __main__ is used as a proxy for "not found" by CPython, so an importer
        # answering "__main__" should be skipped in favour of the next one.
        not_found = DummyImporter("__main__")

        cases = (
            ((foo, bar), "foo"),
            ((bar, foo), "bar"),
            ((not_found, foo), "foo"),
        )
        for chain, expected in cases:
            ordered = OrderedImporter(*chain)
            self.assertEqual(ordered.whichmodule(DummyClass(), ""), expected)

    def test_package_importer_whichmodule_no_dunder_module(self):
        """Exercise corner case where we try to pickle an object whose
        __module__ doesn't exist because it's from a C extension.
        """
        # torch.float16 is such an object. The default pickler locates it by
        # walking sys.modules and looking up `float16` on each module (see
        # pickle.py:whichmodule); PackageImporter has to emulate that lookup.

        def pickle_roundtrip(obj, **exporter_kwargs):
            # Pickle `obj` into a fresh in-memory package, then read it back.
            buf = BytesIO()
            with PackageExporter(buf, **exporter_kwargs) as exporter:
                exporter.save_pickle("foo", "foo.pkl", obj)
            buf.seek(0)
            reader = PackageImporter(buf)
            return reader, reader.load_pickle("foo", "foo.pkl")

        dtype = torch.float16
        first_reader, first_copy = pickle_roundtrip(dtype)
        # Re-export with the PackageImporter as the only name resolver.
        _, second_copy = pickle_roundtrip(first_copy, importer=first_reader)

        self.assertIs(dtype, first_copy)
        self.assertIs(dtype, second_copy)


if __name__ == "__main__":
    run_tests()