xref: /aosp_15_r20/external/pytorch/torch/package/importer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import importlib
3from abc import ABC, abstractmethod
4from pickle import (  # type: ignore[attr-defined]
5    _getattribute,
6    _Pickler,
7    whichmodule as _pickle_whichmodule,
8)
9from types import ModuleType
10from typing import Any, Dict, List, Optional, Tuple
11
12from ._mangling import demangle, get_mangle_prefix, is_mangled
13
14
# Public names exported by ``from <this module> import *``. Note that the
# module-level ``sys_importer`` instance defined below is not listed here.
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]
16
17
class ObjNotFoundError(Exception):
    """Signals that an importer could not locate an object when looking it up by name."""

    pass
20
21
class ObjMismatchError(Exception):
    """Signals that an importer resolved the requested name to a different object
    than the one the user supplied."""

    pass
24
25
class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out which module an object belongs to by checking
    __module__ and importing the result using __import__ or importlib.import_module.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that the following round-trip will succeed or throw an
    ObjNotFoundError/ObjMismatchError (``obj_name`` may be a dotted path such
    as a nested class's ``__qualname__``; resolve it one component at a time):
        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = getattr(module, obj_name)
        assert obj is obj2
    """

    # Mapping of module name -> module, searched by the default ``whichmodule``.
    # It is only annotated here; concrete subclasses that rely on the default
    # ``whichmodule`` must provide it.
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.
        """

    @staticmethod
    def _resolve_dotted_name(obj: Any, name: str) -> Any:
        """Resolve a possibly dotted attribute path (e.g. ``"Outer.Inner"``) on `obj`.

        This performs the dotted lookup that was previously delegated to
        pickle's private ``_getattribute`` helper. That helper is not part of
        pickle's public API, so its signature may change between Python versions.

        Args:
            obj: The object (typically a module) to start the lookup from.
            name: Attribute path whose components are separated by ``"."``.

        Returns:
            The object found at the end of the path.

        Raises:
            AttributeError: if a component is missing, or if the path passes
                through a ``<locals>`` scope (objects defined inside a function
                body cannot be reached by attribute lookup).
        """
        for part in name.split("."):
            if part == "<locals>":
                raise AttributeError(f"Can't get local attribute {name!r}")
            try:
                obj = getattr(obj, part)
            except AttributeError:
                # The message omits repr(obj): whichmodule probes many modules and
                # expects most probes to fail, so building reprs would be wasted work.
                raise AttributeError(f"Can't get attribute {name!r}") from None
        return obj

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.
        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            ``attr_name`` may be a dotted path (e.g. a nested class's ``__qualname__``).
            Use it like:
                obj = importer.import_module(parent_module_name)
                for part in attr_name.split("."):
                    obj = getattr(obj, part)

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name.
            ObjMismatchError: we found a different object with the same name as `obj`.
            AttributeError: no `name` was given or derived from a string-returning
                ``__reduce__``, and `obj` has neither ``__qualname__`` nor ``__name__``.
        """
        if (
            name is None
            # Use an explicit None check, not truthiness. bool(obj) can raise
            # (e.g. for array-like objects). It can also be False for an object
            # that still has a meaningful string-returning __reduce__.
            and obj is not None
            and _Pickler.dispatch.get(type(obj)) is None
        ):
            # Honor the string return variant of __reduce__, which will give us
            # a global name to search for in this environment.
            # TODO: I guess we should do copyreg too?
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                    if isinstance(rv, str):
                        name = rv
                except Exception:
                    # Deliberately best-effort: if __reduce__ fails we fall back
                    # to __qualname__ / __name__ below.
                    pass
        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Demangle the module name before importing. If this obj came out of a
        # PackageImporter, `__module__` will be mangled. See mangling.md for
        # details.
        module_name = demangle(orig_module_name)

        # Check that this name will indeed return the correct object
        try:
            module = self.import_module(module_name)
            obj2 = self._resolve_dotted_name(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        def get_obj_info(target):
            """Return (module name, human-readable location, importer description) for `target`."""
            assert name is not None
            target_module_name = self.whichmodule(target, name)
            is_mangled_ = is_mangled(target_module_name)
            location = (
                get_mangle_prefix(target_module_name)
                if is_mangled_
                else "the current Python environment"
            )
            importer_name = (
                f"the importer for {get_mangle_prefix(target_module_name)}"
                if is_mangled_
                else "'sys_importer'"
            )
            return target_module_name, location, importer_name

        obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
        obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Adapted from pickle.py's ``whichmodule``, but searches ``self.modules``
        instead of ``sys.modules``.

        Returns:
            ``obj.__module__`` if it is set. Otherwise, the name of the first entry
            in ``self.modules`` (skipping ``__main__``, ``__mp_main__`` and ``None``
            modules) in which `name` resolves to `obj` itself. Otherwise ``"__main__"``.
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Protect the iteration by using a copy of self.modules against dynamic
        # modules that trigger imports of other modules upon calls to getattr.
        for module_name, module in self.modules.copy().items():
            if module_name in ("__main__", "__mp_main__") or module is None:  # bpo-42406
                continue
            try:
                if self._resolve_dotted_name(module, name) is obj:
                    return module_name
            except AttributeError:
                pass

        return "__main__"
165
166
class _SysImporter(Importer):
    """Importer that resolves everything through the interpreter's normal import machinery."""

    def import_module(self, module_name: str):
        # Hand the request straight to the standard library importer.
        imported = importlib.import_module(module_name)
        return imported

    def whichmodule(self, obj: Any, name: str) -> str:
        # Defer to pickle's own lookup instead of the base class search over
        # ``self.modules`` (this importer never populates that mapping).
        owner = _pickle_whichmodule(obj, name)
        return owner


# Module-level singleton instance of the system importer.
sys_importer = _SysImporter()
178
179
class OrderedImporter(Importer):
    """Composite importer that consults a sequence of importers in order.

    Whichever importer yields a result first wins; later importers are only
    consulted when earlier ones fail.
    """

    def __init__(self, *args):
        self._importers: List[Importer] = list(args)

    def _is_torchpackage_dummy(self, module):
        """Return True iff `module` is an empty PackageNode produced by torch.package.

        Interning `a.b` without ever using `a` leaves `a` as an empty module with
        no source. That breaks re-packaging an object after a real dependency on
        `a` has been added, because OrderedImporter would otherwise resolve `a`
        to the dummy package and stop searching.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        # Only package modules (those with __path__) that come from torch.package
        # can be dummies.
        if not getattr(module, "__torch_package__", False) or not hasattr(
            module, "__path__"
        ):
            return False
        # A dummy has no backing file: __file__ is either absent or None.
        return getattr(module, "__file__", None) is None

    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from the first importer that provides a real module.

        Dummy torch.package placeholder packages are skipped. If no importer
        succeeds, the most recent ModuleNotFoundError is re-raised, or a fresh
        one if none was recorded.

        Raises:
            TypeError: an entry in the importer list is not an Importer.
            ModuleNotFoundError: no importer could supply a non-dummy module.
        """
        most_recent_error = None
        for candidate in self._importers:
            if not isinstance(candidate, Importer):
                raise TypeError(
                    f"{candidate} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                module = candidate.import_module(module_name)
                if not self._is_torchpackage_dummy(module):
                    return module
            except ModuleNotFoundError as err:
                most_recent_error = err

        if most_recent_error is None:
            raise ModuleNotFoundError(module_name)
        raise most_recent_error

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the first module name, other than "__main__", reported by any importer."""
        # Lazily query importers in order so later ones are only asked when
        # earlier ones could not place the object.
        answers = (importer.whichmodule(obj, name) for importer in self._importers)
        return next((answer for answer in answers if answer != "__main__"), "__main__")
235