# Owner(s): ["oncall: package/deploy"]

from io import BytesIO

import torch
from torch.package import (
    Importer,
    OrderedImporter,
    PackageExporter,
    PackageImporter,
    sys_importer,
)
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestImporter(PackageTestCase):
    """Exercise ``Importer`` and its concrete implementations.

    Covers ``sys_importer`` (the live Python environment), ``PackageImporter``
    (modules loaded out of a package archive) and ``OrderedImporter`` (which
    consults several importers in the order they are given).
    """

    @staticmethod
    def _package_a_importer():
        """Save ``package_a`` into an in-memory package and open it for import."""
        import package_a

        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.save_module(package_a.__name__)
        # Rewind so PackageImporter reads the archive from the beginning.
        archive.seek(0)
        return PackageImporter(archive)

    def test_sys_importer(self):
        import package_a
        import package_a.subpackage

        # sys_importer must hand back exactly the modules already imported here.
        expectations = (
            ("package_a", package_a),
            ("package_a.subpackage", package_a.subpackage),
        )
        for qualified_name, live_module in expectations:
            self.assertIs(sys_importer.import_module(qualified_name), live_module)

    def test_sys_importer_roundtrip(self):
        import package_a
        import package_a.subpackage

        target = package_a.subpackage.PackageASubpackageObject

        # get_name() -> (module, attribute); resolving that pair again must
        # land on the original type object.
        owner_module, attr = sys_importer.get_name(target)
        resolved = getattr(sys_importer.import_module(owner_module), attr)
        self.assertIs(resolved, target)

    def test_single_ordered_importer(self):
        import module_a  # noqa: F401
        import package_a

        pkg_importer = self._package_a_importer()

        # An OrderedImporter that can only see the package contents.
        package_only = OrderedImporter(pkg_importer)

        # It resolves to the copy living inside the package...
        self.assertIs(
            package_only.import_module("package_a"),
            pkg_importer.import_module("package_a"),
        )
        # ...and not to the copy from the surrounding Python environment.
        self.assertIsNot(package_only.import_module("package_a"), package_a)

        # module_a was never saved into the package, so lookup must fail.
        with self.assertRaises(ModuleNotFoundError):
            package_only.import_module("module_a")

    def test_ordered_importer_basic(self):
        import package_a

        pkg_importer = self._package_a_importer()

        # With sys_importer listed first, the live module takes priority.
        sys_first = OrderedImporter(sys_importer, pkg_importer)
        self.assertIs(sys_first.import_module("package_a"), package_a)

        # With the PackageImporter listed first, the packaged copy wins.
        package_first = OrderedImporter(pkg_importer, sys_importer)
        self.assertIs(
            package_first.import_module("package_a"),
            pkg_importer.import_module("package_a"),
        )

    def test_ordered_importer_whichmodule(self):
        """OrderedImporter.whichmodule should query the wrapped importers'
        whichmodule one after another, in the order they were supplied.
        """

        class DummyImporter(Importer):
            def __init__(self, answer):
                self._whichmodule_return = answer

            def import_module(self, module_name):
                raise NotImplementedError

            def whichmodule(self, obj, name):
                return self._whichmodule_return

        class DummyClass:
            pass

        says_foo = DummyImporter("foo")
        says_bar = DummyImporter("bar")
        # CPython reports "__main__" when it cannot locate an object, so this
        # importer stands in for one that does not know the object.
        says_not_found = DummyImporter("__main__")

        scenarios = (
            ((says_foo, says_bar), "foo"),
            ((says_bar, says_foo), "bar"),
            ((says_not_found, says_foo), "foo"),
        )
        for chain, expected_module in scenarios:
            ordered = OrderedImporter(*chain)
            self.assertEqual(ordered.whichmodule(DummyClass(), ""), expected_module)

    def test_package_importer_whichmodule_no_dunder_module(self):
        """Pickle an object that exposes no ``__module__`` because it comes
        from a C extension, and make sure it survives re-export through a
        PackageImporter.
        """
        # torch.float16 is such an object: a C extension type instance with no
        # __module__. The stock pickler locates it by scanning sys.modules for
        # an attribute named `float16` (see pickle.py:whichmodule), and
        # PackageImporter has to mimic that same lookup.
        source_dtype = torch.float16

        # First round trip: package the dtype with the default importer.
        first_archive = BytesIO()
        with PackageExporter(first_archive) as exporter:
            exporter.save_pickle("foo", "foo.pkl", source_dtype)
        first_archive.seek(0)

        first_importer = PackageImporter(first_archive)
        first_copy = first_importer.load_pickle("foo", "foo.pkl")

        # Second round trip: re-save using our PackageImporter as the only importer.
        second_archive = BytesIO()
        with PackageExporter(second_archive, importer=first_importer) as exporter:
            exporter.save_pickle("foo", "foo.pkl", first_copy)
        second_archive.seek(0)

        second_importer = PackageImporter(second_archive)
        second_copy = second_importer.load_pickle("foo", "foo.pkl")

        # Each round trip must yield the very same object as torch.float16.
        for restored in (first_copy, second_copy):
            self.assertIs(source_dtype, restored)


if __name__ == "__main__":
    run_tests()