# mypy: allow-untyped-defs
import builtins
import importlib
import importlib.machinery
import importlib.util
import inspect
import io
import linecache
import os
import types
from contextlib import contextmanager
from typing import (
    Any,
    BinaryIO,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    TYPE_CHECKING,
    Union,
)
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .importer import Importer


if TYPE_CHECKING:
    from .glob_group import GlobPattern

__all__ = ["PackageImporter"]


# This is a list of imports that are implicitly allowed even if they haven't
# been marked as extern. This works around the fact that torch implicitly
# depends on numpy, and torch.package can't track that dependency.
# https://github.com/pytorch/MultiPy/issues/46
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
    "numpy",
    "numpy.core",
    "numpy.core._multiarray_umath",
    # FX GraphModule might depend on the builtins module, and users usually
    # don't extern builtins. So we allow it by default.
    "builtins",
]


# Compatibility name mapping to facilitate upgrade of external modules.
# The primary motivation is to enable a NumPy upgrade for the many modules
# that depend on it: NumPy 1.24 removed `numpy.str` and `numpy.bool`,
# breaking unpickling for many modules.
EXTERN_IMPORT_COMPAT_NAME_MAPPING: Dict[str, Dict[str, Any]] = {
    "numpy": {
        "str": str,
        "bool": bool,
    },
}
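# Illustration (hedged): with the mapping above, packaged code that still
# references `numpy.str` or `numpy.bool` keeps unpickling after `numpy` is
# loaded through this importer, because `_load_module` below patches the
# missing names onto the module with `setdefault` instead of letting an
# AttributeError surface.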


class PackageImporter(Importer):
    """Importers allow you to load code written to packages by :class:`PackageExporter`.
    Code is loaded in a hermetic way, using files from the package
    rather than the normal Python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external during export.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies, where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.
    """

    """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
    local to this importer.
    """

    modules: Dict[str, types.ModuleType]

    def __init__(
        self,
        file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
        module_allowed: Callable[[str], bool] = lambda module_name: True,
    ):
        """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
        allowed by ``module_allowed``.

        Args:
            file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
                a string, or an ``os.PathLike`` object containing a filename.
            module_allowed (Callable[[str], bool], optional): A method to determine if an externally provided module
                should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
                does not support. Defaults to allowing anything.

        Raises:
            ImportError: If the package will use a disallowed module.
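
        Example (a minimal sketch; ``"my_package.pt"`` is a hypothetical
        archive created with :class:`PackageExporter`)::

            >>> # xdoctest: +SKIP
            >>> importer = PackageImporter("my_package.pt")
            >>> # Restrict which extern modules the loaded package may use:
            >>> importer = PackageImporter(
            ...     "my_package.pt",
            ...     module_allowed=lambda name: name in {"torch", "numpy", "builtins"},
            ... )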
        """
        torch._C._log_api_usage_once("torch.package.PackageImporter")

        self.zip_reader: Any
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (os.PathLike, str)):
            self.filename = os.fspath(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

        torch._C._log_api_usage_metadata(
            "torch.package.PackageImporter.metadata",
            {
                "serialization_id": self.zip_reader.serialization_id(),
                "file_name": self.filename,
            },
        )

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
            self._add_extern(extern_module)

        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)

    def import_module(self, name: str, package=None):
        """Load a module from the package if it hasn't already been loaded, and then return
        the module. Modules are loaded locally
        to the importer and will appear in ``self.modules`` rather than ``sys.modules``.

        Args:
            name (str): Fully qualified name of the module to load.
            package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.

        Returns:
            types.ModuleType: The (possibly already) loaded module.
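
        Example (a sketch; assumes an ``importer`` opened as above, with a
        hypothetical ``my_package.my_subpackage`` interned in the archive)::

            >>> # xdoctest: +SKIP
            >>> mod = importer.import_module("my_package.my_subpackage")
            >>> # The module lives in the importer's namespace, not sys.modules:
            >>> mod is importer.modules["my_package.my_subpackage"]
            True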
        """
        # We should always be able to support importing modules from this package.
        # This is to support something like:
        #   obj = importer.load_pickle(...)
        #   importer.import_module(obj.__module__)  <- this string will be mangled
        #
        # Note that _mangler.demangle will not demangle any module names
        # produced by a different PackageImporter instance.
        name = self._mangler.demangle(name)

        return self._gcd_import(name)

    def load_binary(self, package: str, resource: str) -> bytes:
        """Load raw bytes.

        Args:
            package (str): The name of the module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.

        Returns:
            bytes: The loaded data.
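
        Example (a sketch; the resource name is hypothetical, saved at export
        time with ``PackageExporter.save_binary``)::

            >>> # xdoctest: +SKIP
            >>> data = importer.load_binary("my_package", "my_resource.bin")
            >>> isinstance(data, bytes)
            True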
        """

        path = self._zipfile_path(package, resource)
        return self.zip_reader.get_record(path)

    def load_text(
        self,
        package: str,
        resource: str,
        encoding: str = "utf-8",
        errors: str = "strict",
    ) -> str:
        """Load a string.

        Args:
            package (str): The name of the module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
            errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.

        Returns:
            str: The loaded text.
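
        Example (a sketch; the resource name is hypothetical)::

            >>> # xdoctest: +SKIP
            >>> text = importer.load_text("my_package", "my_config.txt")
            >>> isinstance(text, str)
            True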
        """
        data = self.load_binary(package, resource)
        return data.decode(encoding, errors)

    def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
        """Unpickles the resource from the package, loading any modules that are needed to construct the objects
        using :meth:`import_module`.

        Args:
            package (str): The name of the module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            map_location: Passed to ``torch.load`` to determine how tensors are mapped to devices. Defaults to ``None``.

        Returns:
            Any: The unpickled object.
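
        Example (a sketch; ``"model.pkl"`` is a hypothetical resource saved
        with ``PackageExporter.save_pickle``)::

            >>> # xdoctest: +SKIP
            >>> model = importer.load_pickle("my_package", "model.pkl")
            >>> # Remap any saved CUDA tensors onto the CPU while loading:
            >>> model = importer.load_pickle("my_package", "model.pkl", map_location="cpu")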
        """
        pickle_file = self._zipfile_path(package, resource)
        restore_location = _get_restore_location(map_location)
        loaded_storages = {}
        loaded_reduces = {}
        storage_context = torch._C.DeserializationStorageContext()

        def load_tensor(dtype, size, key, location, restore_location):
            name = f"{key}.storage"

            if storage_context.has_storage(name):
                storage = storage_context.get_storage(name, dtype)._typed_storage()
            else:
                tensor = self.zip_reader.get_storage_from_record(
                    ".data/" + name, size, dtype
                )
                if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
                    storage_context.add_storage(name, tensor)
                storage = tensor._typed_storage()
            loaded_storages[key] = restore_location(storage, location)

        def persistent_load(saved_id):
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]

            if typename == "storage":
                storage_type, key, location, size = data
                dtype = storage_type.dtype

                if key not in loaded_storages:
                    load_tensor(
                        dtype,
                        size,
                        key,
                        _maybe_decode_ascii(location),
                        restore_location,
                    )
                storage = loaded_storages[key]
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                return torch.storage.TypedStorage(
                    wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
                )
            elif typename == "reduce_package":
                # Backward compatibility: packages saved before `reduce_id` was
                # introduced take the two-element path below; objects on that
                # path may erroneously be loaded multiple times.
                if len(data) == 2:
                    func, args = data
                    return func(self, *args)
                reduce_id, func, args = data
                if reduce_id not in loaded_reduces:
                    loaded_reduces[reduce_id] = func(self, *args)
                return loaded_reduces[reduce_id]
            else:
                raise RuntimeError(
                    f"Unknown typename for persistent_load, expected 'storage' or "
                    f"'reduce_package' but got '{typename}'"
                )

        # Load the data (which may in turn use `persistent_load` to load tensors)
        data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
        unpickler = self.Unpickler(data_file)
        unpickler.persistent_load = persistent_load  # type: ignore[assignment]

        @contextmanager
        def set_deserialization_context():
            # to let reduce_package access the deserialization context
            self.storage_context = storage_context
            self.last_map_location = map_location
            try:
                yield
            finally:
                self.storage_context = None
                self.last_map_location = None

        with set_deserialization_context():
            result = unpickler.load()

        # TODO from zdevito:
        #   This stateful weird function will need to be removed in our efforts
        #   to unify the format. It has a race condition if multiple python
        #   threads try to read independent files
        torch._utils._validate_loaded_sparse_tensors()

        return result

    def id(self):
        """
        Returns the internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
        Looks like::

            <torch_package_0>
        """
        return self._mangler.parent_name()

    def file_structure(
        self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
    ) -> Directory:
        """Returns a file structure representation of the package's zipfile.

        Args:
            include (Union[List[str], str]): An optional string (e.g. ``"my_package.my_subpackage"``) or list of strings
                naming the files to be included in the zipfile representation. This can also be
                a glob-style pattern, as described in :meth:`PackageExporter.mock`.

            exclude (Union[List[str], str]): An optional pattern that excludes files whose names match the pattern.

        Returns:
            :class:`Directory`
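
        Example (the include pattern is illustrative)::

            >>> # xdoctest: +SKIP
            >>> # List only the Python sources interned in the archive:
            >>> print(importer.file_structure(include="**/*.py"))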
        """
        return _create_directory_from_file_list(
            self.filename, self.zip_reader.get_all_records(), include, exclude
        )

    def python_version(self):
        """Returns the version of Python that was used to create this package.

        Note: this function is experimental and not forward compatible. The plan is to move this into a lock
        file later on.

        Returns:
            :class:`Optional[str]`: a Python version, e.g. ``"3.8.9"``, or ``None`` if no version was stored with this package
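
        Example::

            >>> # xdoctest: +SKIP
            >>> importer.python_version()  # e.g. '3.8.9', or None if unrecorded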
        """
        python_version_path = ".data/python_version"
        return (
            self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
            if self.zip_reader.has_record(python_version_path)
            else None
        )

    def _read_extern(self):
        return (
            self.zip_reader.get_record(".data/extern_modules")
            .decode("utf-8")
            .splitlines(keepends=False)
        )

    def _make_module(
        self, name: str, filename: Optional[str], is_package: bool, parent: str
    ):
        mangled_filename = self._mangler.mangle(filename) if filename else None
        spec = importlib.machinery.ModuleSpec(
            name,
            self,  # type: ignore[arg-type]
            origin="<package_importer>",
            is_package=is_package,
        )
        module = importlib.util.module_from_spec(spec)
        self.modules[name] = module
        module.__name__ = self._mangler.mangle(name)
        ns = module.__dict__
        ns["__spec__"] = spec
        ns["__loader__"] = self
        ns["__file__"] = mangled_filename
        ns["__cached__"] = None
        ns["__builtins__"] = self.patched_builtins
        ns["__torch_package__"] = True

        # Add this module to our private global registry. It should be unique due to mangling.
        assert module.__name__ not in _package_imported_modules
        _package_imported_modules[module.__name__] = module

        # pre-emptively install on the parent to prevent IMPORT_FROM from trying to
        # access sys.modules
        self._install_on_parent(parent, name, module)

        if filename is not None:
            assert mangled_filename is not None
            # pre-emptively install the source in `linecache` so that stack traces,
            # `inspect`, etc. work.
            assert filename not in linecache.cache  # type: ignore[attr-defined]
            linecache.lazycache(mangled_filename, ns)

            code = self._compile_source(filename, mangled_filename)
            exec(code, ns)

        return module

    def _load_module(self, name: str, parent: str):
        cur: _PathNode = self.root
        for atom in name.split("."):
            if not isinstance(cur, _PackageNode) or atom not in cur.children:
                if name in IMPLICIT_IMPORT_ALLOWLIST:
                    module = self.modules[name] = importlib.import_module(name)
                    return module
                raise ModuleNotFoundError(
                    f'No module named "{name}" in self-contained archive "{self.filename}"'
                    f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
                    name=name,
                )
            cur = cur.children[atom]
            if isinstance(cur, _ExternNode):
                module = self.modules[name] = importlib.import_module(name)

                if compat_mapping := EXTERN_IMPORT_COMPAT_NAME_MAPPING.get(name):
                    for old_name, new_name in compat_mapping.items():
                        module.__dict__.setdefault(old_name, new_name)

                return module
        return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent)  # type: ignore[attr-defined]

    def _compile_source(self, fullpath: str, mangled_filename: str):
        source = self.zip_reader.get_record(fullpath)
        source = _normalize_line_endings(source)
        return compile(source, mangled_filename, "exec", dont_inherit=True)

    # note: named `get_source` so that linecache can find the source
    # when this is the __loader__ of a module.
    def get_source(self, module_name) -> str:
        # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
        module = self.import_module(demangle(module_name))
        return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")

    # note: named `get_resource_reader` so that importlib.resources can find it.
    # This is otherwise considered an internal method.
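    # Illustrative use (resource names are hypothetical):
    #   reader = importer.get_resource_reader("my_package")
    #   data = reader.open_resource("my_resource.bin").read()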
    def get_resource_reader(self, fullname):
        try:
            package = self._get_package(fullname)
        except ImportError:
            return None
        if package.__loader__ is not self:
            return None
        return _PackageResourceReader(self, fullname)

    def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
        if not parent:
            return
        # Set the module as an attribute on its parent.
        parent_module = self.modules[parent]
        if parent_module.__loader__ is self:
            setattr(parent_module, name.rpartition(".")[2], module)

    # note: copied from cpython's import code, with call to create module replaced with _make_module
    def _do_find_and_load(self, name):
        path = None
        parent = name.rpartition(".")[0]
        module_name_no_parent = name.rpartition(".")[-1]
        if parent:
            if parent not in self.modules:
                self._gcd_import(parent)
            # Crazy side-effects!
            if name in self.modules:
                return self.modules[name]
            parent_module = self.modules[parent]

            try:
                path = parent_module.__path__  # type: ignore[attr-defined]
            except AttributeError:
                # When we attempt to import a package containing only pybind'd
                # files, the parent directory isn't always a package as defined
                # by Python, so we check whether the module is actually present
                # before raising the error.
                if isinstance(
                    parent_module.__loader__,
                    importlib.machinery.ExtensionFileLoader,
                ):
                    if name not in self.extern_modules:
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a C extension module which was not externed. "
                            "C extension modules need to be externed by the "
                            "PackageExporter in order to be used, as we do not "
                            "support interning them."
                        ).format(name, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                    if not isinstance(
                        parent_module.__dict__.get(module_name_no_parent),
                        types.ModuleType,
                    ):
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a C extension package which does not contain {!r}."
                        ).format(name, parent, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                else:
                    msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
                    raise ModuleNotFoundError(msg, name=name) from None

        module = self._load_module(name, parent)

        self._install_on_parent(parent, name, module)

        return module

    # note: copied from cpython's import code
    def _find_and_load(self, name):
        module = self.modules.get(name, _NEEDS_LOADING)
        if module is _NEEDS_LOADING:
            return self._do_find_and_load(name)

        if module is None:
            message = f"import of {name} halted; None in sys.modules"
            raise ModuleNotFoundError(message, name=name)

        # To handle https://github.com/pytorch/pytorch/issues/57490, where the
        # standard library's creation of fake submodules by hacking sys.modules
        # is not import friendly
        if name == "os":
            self.modules["os.path"] = cast(Any, module).path
        elif name == "typing":
            self.modules["typing.io"] = cast(Any, module).io
            self.modules["typing.re"] = cast(Any, module).re

        return module

    def _gcd_import(self, name, package=None, level=0):
        """Import and return the module based on its name, the package the call is
        being made from, and the level adjustment.

        This function represents the greatest common denominator of functionality
        between import_module and __import__. This includes setting __package__ if
        the loader did not.

        """
        _sanity_check(name, package, level)
        if level > 0:
            name = _resolve_name(name, package, level)

        return self._find_and_load(name)

    # note: copied from cpython's import code
    def _handle_fromlist(self, module, fromlist, *, recursive=False):
        """Figure out what __import__ should return.

        Unlike the cpython original, which takes an ``import_`` callable to
        decouple it from a particular import implementation, this adaptation
        always imports through this PackageImporter.

        """
        module_name = demangle(module.__name__)
        # The hell that is fromlist ...
        # If a package was imported, try to import stuff from fromlist.
        if hasattr(module, "__path__"):
            for x in fromlist:
                if not isinstance(x, str):
                    if recursive:
                        where = module_name + ".__all__"
                    else:
                        where = "``from list''"
                    raise TypeError(
                        f"Item in {where} must be str, not {type(x).__name__}"
                    )
                elif x == "*":
                    if not recursive and hasattr(module, "__all__"):
                        self._handle_fromlist(module, module.__all__, recursive=True)
                elif not hasattr(module, x):
                    from_name = f"{module_name}.{x}"
                    try:
                        self._gcd_import(from_name)
                    except ModuleNotFoundError as exc:
                        # Backwards-compatibility dictates we ignore failed
                        # imports triggered by fromlist for modules that don't
                        # exist.
                        if (
                            exc.name == from_name
                            and self.modules.get(from_name, _NEEDS_LOADING) is not None
                        ):
                            continue
                        raise
        return module

    def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
        if level == 0:
            module = self._gcd_import(name)
        else:
            globals_ = globals if globals is not None else {}
            package = _calc___package__(globals_)
            module = self._gcd_import(name, package, level)
        if not fromlist:
            # Return up to the first dot in 'name'. This is complicated by the fact
            # that 'name' may be relative.
            if level == 0:
                return self._gcd_import(name.partition(".")[0])
            elif not name:
                return module
            else:
                # Figure out where to slice the module's name up to the first dot
                # in 'name'.
                cut_off = len(name) - len(name.partition(".")[0])
                # Slice end needs to be positive to alleviate need to special-case
                # when ``'.' not in name``.
                module_name = demangle(module.__name__)
                return self.modules[module_name[: len(module_name) - cut_off]]
        else:
            return self._handle_fromlist(module, fromlist)

    def _get_package(self, package):
        """Take a package name or module object and return the module.

        If a name, the module is imported.  If the passed or imported module
        object is not a package, raise an exception.
        """
        if hasattr(package, "__spec__"):
            if package.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package.__spec__.name!r} is not a package")
            else:
                return package
        else:
            module = self.import_module(package)
            if module.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package!r} is not a package")
            else:
                return module

    def _zipfile_path(self, package, resource=None):
        package = self._get_package(package)
        assert package.__loader__ is self
        name = demangle(package.__name__)
        if resource is not None:
            resource = _normalize_path(resource)
            return f"{name.replace('.', '/')}/{resource}"
        else:
            return name.replace(".", "/")

    def _get_or_create_package(
        self, atoms: List[str]
    ) -> "Union[_PackageNode, _ExternNode]":
        cur = self.root
        for i, atom in enumerate(atoms):
            node = cur.children.get(atom, None)
            if node is None:
                node = cur.children[atom] = _PackageNode(None)
            if isinstance(node, _ExternNode):
                return node
            if isinstance(node, _ModuleNode):
                name = ".".join(atoms[: i + 1])
                raise ImportError(
                    f"inconsistent module structure. module {name} is not a package, but has submodules"
                )
            assert isinstance(node, _PackageNode)
            cur = node
        return cur

    def _add_file(self, filename: str):
        """Assembles a Python module out of the given file. Will ignore files in the .data directory.

        Args:
            filename (str): the name of the file inside of the package archive to be added
        """
        *prefix, last = filename.split("/")
        if len(prefix) > 1 and prefix[0] == ".data":
            return
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            raise ImportError(
                f"inconsistent module structure. package contains a module file {filename}"
                f" that is a subpackage of a module marked external."
            )
        if last == "__init__.py":
            package.source_file = filename
        elif last.endswith(".py"):
            package_name = last[: -len(".py")]
            package.children[package_name] = _ModuleNode(filename)

    def _add_extern(self, extern_name: str):
        *prefix, last = extern_name.split(".")
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            return  # the shorter extern covers this extern case
        package.children[last] = _ExternNode()


_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"


class _PathNode:
    pass


class _PackageNode(_PathNode):
    def __init__(self, source_file: Optional[str]):
        self.source_file = source_file
        self.children: Dict[str, _PathNode] = {}


class _ModuleNode(_PathNode):
    __slots__ = ["source_file"]

    def __init__(self, source_file: str):
        self.source_file = source_file


class _ExternNode(_PathNode):
    pass


# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def _patched_getfile(object):
    if inspect.isclass(object):
        if object.__module__ in _package_imported_modules:
            return _package_imported_modules[object.__module__].__file__
    return _orig_getfile(object)


inspect.getfile = _patched_getfile


class _PackageResourceReader:
    """Private class used to support PackageImporter.get_resource_reader().

    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
    the innards of PackageImporter.
    """

    def __init__(self, importer, fullname):
        self.importer = importer
        self.fullname = fullname

    def open_resource(self, resource):
        from io import BytesIO

        return BytesIO(self.importer.load_binary(self.fullname, resource))

    def resource_path(self, resource):
        # The contract for resource_path is that it either returns a concrete
        # file system path or raises FileNotFoundError.
        if isinstance(
            self.importer.zip_reader, DirectoryReader
        ) and self.importer.zip_reader.has_record(
            os.path.join(self.fullname, resource)
        ):
            return os.path.join(
                self.importer.zip_reader.directory, self.fullname, resource
            )
        raise FileNotFoundError

    def is_resource(self, name):
        path = self.importer._zipfile_path(self.fullname, name)
        return self.importer.zip_reader.has_record(path)

    def contents(self):
        from pathlib import Path

        fullname_path = Path(self.importer._zipfile_path(self.fullname))
        files = self.importer.zip_reader.get_all_records()
        subdirs_seen = set()
        for filename in files:
            try:
                relative = Path(filename).relative_to(fullname_path)
            except ValueError:
                continue
            # If the path of the file (which is relative to the top of the zip
            # namespace), relative to the package given when the resource
            # reader was created, has a parent, then it's a name in a
            # subdirectory and thus we skip it.
            parent_name = relative.parent.name
            if len(parent_name) == 0:
                yield relative.name
            elif parent_name not in subdirs_seen:
                subdirs_seen.add(parent_name)
                yield parent_name