# mypy: allow-untyped-defs
import importlib
from abc import ABC, abstractmethod
from pickle import (  # type: ignore[attr-defined]
    _getattribute,
    _Pickler,
    whichmodule as _pickle_whichmodule,
)
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple

from ._mangling import demangle, get_mangle_prefix, is_mangled


__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]


class ObjNotFoundError(Exception):
    """Raised when an importer cannot find an object by searching for its name."""


class ObjMismatchError(Exception):
    """Raised when an importer found a different object with the same name as the user-provided one."""


class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out what module an object belongs to by checking
    __module__ and importing the result using __import__ or importlib.import_module.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that the following round-trip will succeed or throw an
    ObjNotFoundError/ObjMismatchError.
        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = getattr(module, obj_name)
        assert obj1 is obj2
    """

    # Mapping of module name -> module object that the default `whichmodule`
    # scans when an object carries no `__module__` attribute.
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.
        """

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.
        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            Use it like:
                mod = importer.import_module(parent_module_name)
                obj = getattr(mod, attr_name)

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name.
            ObjMismatchError: we found a different object with the same name as `obj`.
        """
        # Compare against None explicitly instead of relying on truthiness:
        # `bool(obj)` would skip falsy objects (e.g. a sentinel whose
        # `__bool__` returns False) and can raise outright for objects whose
        # truth value is ambiguous.
        if (
            name is None
            and obj is not None
            and _Pickler.dispatch.get(type(obj)) is None
        ):
            # Honor the string return variant of __reduce__, which will give us
            # a global name to search for in this environment.
            # TODO: I guess we should do copyreg too?
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                except Exception:
                    # Deliberately best-effort: if `__reduce__` cannot be
                    # called here, fall back to __qualname__/__name__ below.
                    pass
                else:
                    if isinstance(rv, str):
                        name = rv
        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Demangle the module name before importing. If this obj came out of a
        # PackageImporter, `__module__` will be mangled. See mangling.md for
        # details.
        module_name = demangle(orig_module_name)

        # Check that this name will indeed return the correct object
        try:
            module = self.import_module(module_name)
            obj2, _ = _getattribute(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        # The lookup succeeded but produced a different object. Describe where
        # each of the two objects comes from so the error is actionable.
        def get_obj_info(obj):
            assert name is not None
            module_name = self.whichmodule(obj, name)
            is_mangled_ = is_mangled(module_name)
            location = (
                get_mangle_prefix(module_name)
                if is_mangled_
                else "the current Python environment"
            )
            importer_name = (
                f"the importer for {get_mangle_prefix(module_name)}"
                if is_mangled_
                else "'sys_importer'"
            )
            return module_name, location, importer_name

        obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
        obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Taken from pickle.py, but modified to exclude the search into sys.modules

        Args:
            obj: The object whose defining module is wanted.
            name: The (possibly dotted) attribute path under which `obj` is
                expected to be reachable from its module.

        Returns:
            `obj.__module__` when that attribute is set; otherwise the name of
            the first module in `self.modules` from which `name` resolves to
            `obj`; otherwise "__main__".
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Protect the iteration by using a copy of self.modules against dynamic
        # modules that trigger imports of other modules upon calls to getattr.
        for module_name, module in self.modules.copy().items():
            if (
                module_name == "__main__"
                or module_name == "__mp_main__"  # bpo-42406
                or module is None
            ):
                continue
            try:
                if _getattribute(module, name)[0] is obj:
                    return module_name
            except AttributeError:
                pass

        return "__main__"


class _SysImporter(Importer):
    """An importer that implements the default behavior of Python."""

    def import_module(self, module_name: str):
        """Import `module_name` via importlib.import_module."""
        return importlib.import_module(module_name)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Delegate to pickle's own `whichmodule`."""
        return _pickle_whichmodule(obj, name)


sys_importer = _SysImporter()


class OrderedImporter(Importer):
    """A compound importer that takes a list of importers and tries them one at a time.

    The first importer in the list that returns a result "wins".
    """

    def __init__(self, *args):
        """Store the given importers in the order they should be consulted."""
        self._importers: List[Importer] = list(args)

    def _is_torchpackage_dummy(self, module):
        """Returns true iff this module is an empty PackageNode in a torch.package.

        If you intern `a.b` but never use `a` in your code, then `a` will be an
        empty module with no source. This can break cases where we are trying to
        re-package an object after adding a real dependency on `a`, since
        OrderedImporter will resolve `a` to the dummy package and stop there.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        if not getattr(module, "__torch_package__", False):
            return False
        if not hasattr(module, "__path__"):
            return False
        if not hasattr(module, "__file__"):
            return True
        return module.__file__ is None

    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from the first importer that can provide it.

        Importers are tried in order. A result that is an empty torch.package
        placeholder (see `_is_torchpackage_dummy`) is skipped so that a later
        importer gets a chance to supply the module.

        Raises:
            TypeError: an entry reached during the search is not an `Importer`.
            ModuleNotFoundError: no importer provided the module. The error
                raised by the last failing importer is re-raised when there is
                one.
        """
        last_err = None
        for importer in self._importers:
            if not isinstance(importer, Importer):
                raise TypeError(
                    f"{importer} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                module = importer.import_module(module_name)
                if self._is_torchpackage_dummy(module):
                    continue
                return module
            except ModuleNotFoundError as err:
                last_err = err

        if last_err is not None:
            raise last_err
        else:
            raise ModuleNotFoundError(module_name)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the first answer other than "__main__" given by the importers.

        Each importer's `whichmodule` is asked in order; "__main__" is returned
        only when every importer answers "__main__" (or there are none).
        """
        for importer in self._importers:
            module_name = importer.whichmodule(obj, name)
            if module_name != "__main__":
                return module_name

        return "__main__"