1# mypy: allow-untyped-defs 2import builtins 3import importlib 4import importlib.machinery 5import inspect 6import io 7import linecache 8import os 9import types 10from contextlib import contextmanager 11from typing import ( 12 Any, 13 BinaryIO, 14 Callable, 15 cast, 16 Dict, 17 Iterable, 18 List, 19 Optional, 20 TYPE_CHECKING, 21 Union, 22) 23from weakref import WeakValueDictionary 24 25import torch 26from torch.serialization import _get_restore_location, _maybe_decode_ascii 27 28from ._directory_reader import DirectoryReader 29from ._importlib import ( 30 _calc___package__, 31 _normalize_line_endings, 32 _normalize_path, 33 _resolve_name, 34 _sanity_check, 35) 36from ._mangling import demangle, PackageMangler 37from ._package_unpickler import PackageUnpickler 38from .file_structure_representation import _create_directory_from_file_list, Directory 39from .importer import Importer 40 41 42if TYPE_CHECKING: 43 from .glob_group import GlobPattern 44 45__all__ = ["PackageImporter"] 46 47 48# This is a list of imports that are implicitly allowed even if they haven't 49# been marked as extern. This is to work around the fact that Torch implicitly 50# depends on numpy and package can't track it. 51# https://github.com/pytorch/MultiPy/issues/46 52IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [ 53 "numpy", 54 "numpy.core", 55 "numpy.core._multiarray_umath", 56 # FX GraphModule might depend on builtins module and users usually 57 # don't extern builtins. Here we import it here by default. 58 "builtins", 59] 60 61 62# Compatibility name mapping to facilitate upgrade of external modules. 63# The primary motivation is to enable Numpy upgrade that many modules 64# depend on. The latest release of Numpy removed `numpy.str` and 65# `numpy.bool` breaking unpickling for many modules. 66EXTERN_IMPORT_COMPAT_NAME_MAPPING: Dict[str, Dict[str, Any]] = { 67 "numpy": { 68 "str": str, 69 "bool": bool, 70 }, 71} 72 73 74class PackageImporter(Importer): 75 """Importers allow you to load code written to packages by :class:`PackageExporter`. 76 Code is loaded in a hermetic way, using files from the package 77 rather than the normal python import system. This allows 78 for the packaging of PyTorch model code and data so that it can be run 79 on a server or used in the future for transfer learning. 80 81 The importer for packages ensures that code in the module can only be loaded from 82 within the package, except for modules explicitly listed as external during export. 83 The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on. 84 This prevents "implicit" dependencies where the package runs locally because it is importing 85 a locally-installed package, but then fails when the package is copied to another machine. 86 """ 87 88 """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but 89 local to this importer. 90 """ 91 92 modules: Dict[str, types.ModuleType] 93 94 def __init__( 95 self, 96 file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO], 97 module_allowed: Callable[[str], bool] = lambda module_name: True, 98 ): 99 """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules 100 allowed by ``module_allowed`` 101 102 Args: 103 file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), 104 a string, or an ``os.PathLike`` object containing a filename. 105 module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module 106 should be allowed. Can be used to ensure packages loaded do not depend on modules that the server 107 does not support. Defaults to allowing anything. 108 109 Raises: 110 ImportError: If the package will use a disallowed module. 111 """ 112 torch._C._log_api_usage_once("torch.package.PackageImporter") 113 114 self.zip_reader: Any 115 if isinstance(file_or_buffer, torch._C.PyTorchFileReader): 116 self.filename = "<pytorch_file_reader>" 117 self.zip_reader = file_or_buffer 118 elif isinstance(file_or_buffer, (os.PathLike, str)): 119 self.filename = os.fspath(file_or_buffer) 120 if not os.path.isdir(self.filename): 121 self.zip_reader = torch._C.PyTorchFileReader(self.filename) 122 else: 123 self.zip_reader = DirectoryReader(self.filename) 124 else: 125 self.filename = "<binary>" 126 self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer) 127 128 torch._C._log_api_usage_metadata( 129 "torch.package.PackageImporter.metadata", 130 { 131 "serialization_id": self.zip_reader.serialization_id(), 132 "file_name": self.filename, 133 }, 134 ) 135 136 self.root = _PackageNode(None) 137 self.modules = {} 138 self.extern_modules = self._read_extern() 139 140 for extern_module in self.extern_modules: 141 if not module_allowed(extern_module): 142 raise ImportError( 143 f"package '{file_or_buffer}' needs the external module '{extern_module}' " 144 f"but that module has been disallowed" 145 ) 146 self._add_extern(extern_module) 147 148 for fname in self.zip_reader.get_all_records(): 149 self._add_file(fname) 150 151 self.patched_builtins = builtins.__dict__.copy() 152 self.patched_builtins["__import__"] = self.__import__ 153 # Allow packaged modules to reference their PackageImporter 154 self.modules["torch_package_importer"] = self # type: ignore[assignment] 155 156 self._mangler = PackageMangler() 157 158 # used for reduce deserializaiton 159 self.storage_context: Any = None 160 self.last_map_location = None 161 162 # used for torch.serialization._load 163 self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs) 164 165 def import_module(self, name: str, package=None): 166 """Load a module from the package if it hasn't already been loaded, and then return 167 the module. Modules are loaded locally 168 to the importer and will appear in ``self.modules`` rather than ``sys.modules``. 169 170 Args: 171 name (str): Fully qualified name of the module to load. 172 package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``. 173 174 Returns: 175 types.ModuleType: The (possibly already) loaded module. 176 """ 177 # We should always be able to support importing modules from this package. 178 # This is to support something like: 179 # obj = importer.load_pickle(...) 180 # importer.import_module(obj.__module__) <- this string will be mangled 181 # 182 # Note that _mangler.demangle will not demangle any module names 183 # produced by a different PackageImporter instance. 184 name = self._mangler.demangle(name) 185 186 return self._gcd_import(name) 187 188 def load_binary(self, package: str, resource: str) -> bytes: 189 """Load raw bytes. 190 191 Args: 192 package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). 193 resource (str): The unique name for the resource. 194 195 Returns: 196 bytes: The loaded data. 197 """ 198 199 path = self._zipfile_path(package, resource) 200 return self.zip_reader.get_record(path) 201 202 def load_text( 203 self, 204 package: str, 205 resource: str, 206 encoding: str = "utf-8", 207 errors: str = "strict", 208 ) -> str: 209 """Load a string. 210 211 Args: 212 package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). 213 resource (str): The unique name for the resource. 214 encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``. 215 errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``. 216 217 Returns: 218 str: The loaded text. 219 """ 220 data = self.load_binary(package, resource) 221 return data.decode(encoding, errors) 222 223 def load_pickle(self, package: str, resource: str, map_location=None) -> Any: 224 """Unpickles the resource from the package, loading any modules that are needed to construct the objects 225 using :meth:`import_module`. 226 227 Args: 228 package (str): The name of module package (e.g. ``"my_package.my_subpackage"``). 229 resource (str): The unique name for the resource. 230 map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``. 231 232 Returns: 233 Any: The unpickled object. 234 """ 235 pickle_file = self._zipfile_path(package, resource) 236 restore_location = _get_restore_location(map_location) 237 loaded_storages = {} 238 loaded_reduces = {} 239 storage_context = torch._C.DeserializationStorageContext() 240 241 def load_tensor(dtype, size, key, location, restore_location): 242 name = f"{key}.storage" 243 244 if storage_context.has_storage(name): 245 storage = storage_context.get_storage(name, dtype)._typed_storage() 246 else: 247 tensor = self.zip_reader.get_storage_from_record( 248 ".data/" + name, size, dtype 249 ) 250 if isinstance(self.zip_reader, torch._C.PyTorchFileReader): 251 storage_context.add_storage(name, tensor) 252 storage = tensor._typed_storage() 253 loaded_storages[key] = restore_location(storage, location) 254 255 def persistent_load(saved_id): 256 assert isinstance(saved_id, tuple) 257 typename = _maybe_decode_ascii(saved_id[0]) 258 data = saved_id[1:] 259 260 if typename == "storage": 261 storage_type, key, location, size = data 262 dtype = storage_type.dtype 263 264 if key not in loaded_storages: 265 load_tensor( 266 dtype, 267 size, 268 key, 269 _maybe_decode_ascii(location), 270 restore_location, 271 ) 272 storage = loaded_storages[key] 273 # TODO: Once we decide to break serialization FC, we can 274 # stop wrapping with TypedStorage 275 return torch.storage.TypedStorage( 276 wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True 277 ) 278 elif typename == "reduce_package": 279 # to fix BC breaking change, objects on this load path 280 # will be loaded multiple times erroneously 281 if len(data) == 2: 282 func, args = data 283 return func(self, *args) 284 reduce_id, func, args = data 285 if reduce_id not in loaded_reduces: 286 loaded_reduces[reduce_id] = func(self, *args) 287 return loaded_reduces[reduce_id] 288 else: 289 f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'" 290 291 # Load the data (which may in turn use `persistent_load` to load tensors) 292 data_file = io.BytesIO(self.zip_reader.get_record(pickle_file)) 293 unpickler = self.Unpickler(data_file) 294 unpickler.persistent_load = persistent_load # type: ignore[assignment] 295 296 @contextmanager 297 def set_deserialization_context(): 298 # to let reduce_package access deserializaiton context 299 self.storage_context = storage_context 300 self.last_map_location = map_location 301 try: 302 yield 303 finally: 304 self.storage_context = None 305 self.last_map_location = None 306 307 with set_deserialization_context(): 308 result = unpickler.load() 309 310 # TODO from zdevito: 311 # This stateful weird function will need to be removed in our efforts 312 # to unify the format. It has a race condition if multiple python 313 # threads try to read independent files 314 torch._utils._validate_loaded_sparse_tensors() 315 316 return result 317 318 def id(self): 319 """ 320 Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances. 321 Looks like:: 322 323 <torch_package_0> 324 """ 325 return self._mangler.parent_name() 326 327 def file_structure( 328 self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = () 329 ) -> Directory: 330 """Returns a file structure representation of package's zipfile. 331 332 Args: 333 include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings 334 for the names of the files to be included in the zipfile representation. This can also be 335 a glob-style pattern, as described in :meth:`PackageExporter.mock` 336 337 exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern. 338 339 Returns: 340 :class:`Directory` 341 """ 342 return _create_directory_from_file_list( 343 self.filename, self.zip_reader.get_all_records(), include, exclude 344 ) 345 346 def python_version(self): 347 """Returns the version of python that was used to create this package. 348 349 Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock 350 file later on. 351 352 Returns: 353 :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package 354 """ 355 python_version_path = ".data/python_version" 356 return ( 357 self.zip_reader.get_record(python_version_path).decode("utf-8").strip() 358 if self.zip_reader.has_record(python_version_path) 359 else None 360 ) 361 362 def _read_extern(self): 363 return ( 364 self.zip_reader.get_record(".data/extern_modules") 365 .decode("utf-8") 366 .splitlines(keepends=False) 367 ) 368 369 def _make_module( 370 self, name: str, filename: Optional[str], is_package: bool, parent: str 371 ): 372 mangled_filename = self._mangler.mangle(filename) if filename else None 373 spec = importlib.machinery.ModuleSpec( 374 name, 375 self, # type: ignore[arg-type] 376 origin="<package_importer>", 377 is_package=is_package, 378 ) 379 module = importlib.util.module_from_spec(spec) 380 self.modules[name] = module 381 module.__name__ = self._mangler.mangle(name) 382 ns = module.__dict__ 383 ns["__spec__"] = spec 384 ns["__loader__"] = self 385 ns["__file__"] = mangled_filename 386 ns["__cached__"] = None 387 ns["__builtins__"] = self.patched_builtins 388 ns["__torch_package__"] = True 389 390 # Add this module to our private global registry. It should be unique due to mangling. 391 assert module.__name__ not in _package_imported_modules 392 _package_imported_modules[module.__name__] = module 393 394 # pre-emptively install on the parent to prevent IMPORT_FROM from trying to 395 # access sys.modules 396 self._install_on_parent(parent, name, module) 397 398 if filename is not None: 399 assert mangled_filename is not None 400 # pre-emptively install the source in `linecache` so that stack traces, 401 # `inspect`, etc. work. 402 assert filename not in linecache.cache # type: ignore[attr-defined] 403 linecache.lazycache(mangled_filename, ns) 404 405 code = self._compile_source(filename, mangled_filename) 406 exec(code, ns) 407 408 return module 409 410 def _load_module(self, name: str, parent: str): 411 cur: _PathNode = self.root 412 for atom in name.split("."): 413 if not isinstance(cur, _PackageNode) or atom not in cur.children: 414 if name in IMPLICIT_IMPORT_ALLOWLIST: 415 module = self.modules[name] = importlib.import_module(name) 416 return module 417 raise ModuleNotFoundError( 418 f'No module named "{name}" in self-contained archive "{self.filename}"' 419 f" and the module is also not in the list of allowed external modules: {self.extern_modules}", 420 name=name, 421 ) 422 cur = cur.children[atom] 423 if isinstance(cur, _ExternNode): 424 module = self.modules[name] = importlib.import_module(name) 425 426 if compat_mapping := EXTERN_IMPORT_COMPAT_NAME_MAPPING.get(name): 427 for old_name, new_name in compat_mapping.items(): 428 module.__dict__.setdefault(old_name, new_name) 429 430 return module 431 return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined] 432 433 def _compile_source(self, fullpath: str, mangled_filename: str): 434 source = self.zip_reader.get_record(fullpath) 435 source = _normalize_line_endings(source) 436 return compile(source, mangled_filename, "exec", dont_inherit=True) 437 438 # note: named `get_source` so that linecache can find the source 439 # when this is the __loader__ of a module. 440 def get_source(self, module_name) -> str: 441 # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here. 442 module = self.import_module(demangle(module_name)) 443 return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8") 444 445 # note: named `get_resource_reader` so that importlib.resources can find it. 446 # This is otherwise considered an internal method. 447 def get_resource_reader(self, fullname): 448 try: 449 package = self._get_package(fullname) 450 except ImportError: 451 return None 452 if package.__loader__ is not self: 453 return None 454 return _PackageResourceReader(self, fullname) 455 456 def _install_on_parent(self, parent: str, name: str, module: types.ModuleType): 457 if not parent: 458 return 459 # Set the module as an attribute on its parent. 460 parent_module = self.modules[parent] 461 if parent_module.__loader__ is self: 462 setattr(parent_module, name.rpartition(".")[2], module) 463 464 # note: copied from cpython's import code, with call to create module replaced with _make_module 465 def _do_find_and_load(self, name): 466 path = None 467 parent = name.rpartition(".")[0] 468 module_name_no_parent = name.rpartition(".")[-1] 469 if parent: 470 if parent not in self.modules: 471 self._gcd_import(parent) 472 # Crazy side-effects! 473 if name in self.modules: 474 return self.modules[name] 475 parent_module = self.modules[parent] 476 477 try: 478 path = parent_module.__path__ # type: ignore[attr-defined] 479 480 except AttributeError: 481 # when we attempt to import a package only containing pybinded files, 482 # the parent directory isn't always a package as defined by python, 483 # so we search if the package is actually there or not before calling the error. 484 if isinstance( 485 parent_module.__loader__, 486 importlib.machinery.ExtensionFileLoader, 487 ): 488 if name not in self.extern_modules: 489 msg = ( 490 _ERR_MSG 491 + "; {!r} is a c extension module which was not externed. C extension modules \ 492 need to be externed by the PackageExporter in order to be used as we do not support interning them.}." 493 ).format(name, name) 494 raise ModuleNotFoundError(msg, name=name) from None 495 if not isinstance( 496 parent_module.__dict__.get(module_name_no_parent), 497 types.ModuleType, 498 ): 499 msg = ( 500 _ERR_MSG 501 + "; {!r} is a c extension package which does not contain {!r}." 502 ).format(name, parent, name) 503 raise ModuleNotFoundError(msg, name=name) from None 504 else: 505 msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent) 506 raise ModuleNotFoundError(msg, name=name) from None 507 508 module = self._load_module(name, parent) 509 510 self._install_on_parent(parent, name, module) 511 512 return module 513 514 # note: copied from cpython's import code 515 def _find_and_load(self, name): 516 module = self.modules.get(name, _NEEDS_LOADING) 517 if module is _NEEDS_LOADING: 518 return self._do_find_and_load(name) 519 520 if module is None: 521 message = f"import of {name} halted; None in sys.modules" 522 raise ModuleNotFoundError(message, name=name) 523 524 # To handle https://github.com/pytorch/pytorch/issues/57490, where std's 525 # creation of fake submodules via the hacking of sys.modules is not import 526 # friendly 527 if name == "os": 528 self.modules["os.path"] = cast(Any, module).path 529 elif name == "typing": 530 self.modules["typing.io"] = cast(Any, module).io 531 self.modules["typing.re"] = cast(Any, module).re 532 533 return module 534 535 def _gcd_import(self, name, package=None, level=0): 536 """Import and return the module based on its name, the package the call is 537 being made from, and the level adjustment. 538 539 This function represents the greatest common denominator of functionality 540 between import_module and __import__. This includes setting __package__ if 541 the loader did not. 542 543 """ 544 _sanity_check(name, package, level) 545 if level > 0: 546 name = _resolve_name(name, package, level) 547 548 return self._find_and_load(name) 549 550 # note: copied from cpython's import code 551 def _handle_fromlist(self, module, fromlist, *, recursive=False): 552 """Figure out what __import__ should return. 553 554 The import_ parameter is a callable which takes the name of module to 555 import. It is required to decouple the function from assuming importlib's 556 import implementation is desired. 557 558 """ 559 module_name = demangle(module.__name__) 560 # The hell that is fromlist ... 561 # If a package was imported, try to import stuff from fromlist. 562 if hasattr(module, "__path__"): 563 for x in fromlist: 564 if not isinstance(x, str): 565 if recursive: 566 where = module_name + ".__all__" 567 else: 568 where = "``from list''" 569 raise TypeError( 570 f"Item in {where} must be str, " f"not {type(x).__name__}" 571 ) 572 elif x == "*": 573 if not recursive and hasattr(module, "__all__"): 574 self._handle_fromlist(module, module.__all__, recursive=True) 575 elif not hasattr(module, x): 576 from_name = f"{module_name}.{x}" 577 try: 578 self._gcd_import(from_name) 579 except ModuleNotFoundError as exc: 580 # Backwards-compatibility dictates we ignore failed 581 # imports triggered by fromlist for modules that don't 582 # exist. 583 if ( 584 exc.name == from_name 585 and self.modules.get(from_name, _NEEDS_LOADING) is not None 586 ): 587 continue 588 raise 589 return module 590 591 def __import__(self, name, globals=None, locals=None, fromlist=(), level=0): 592 if level == 0: 593 module = self._gcd_import(name) 594 else: 595 globals_ = globals if globals is not None else {} 596 package = _calc___package__(globals_) 597 module = self._gcd_import(name, package, level) 598 if not fromlist: 599 # Return up to the first dot in 'name'. This is complicated by the fact 600 # that 'name' may be relative. 601 if level == 0: 602 return self._gcd_import(name.partition(".")[0]) 603 elif not name: 604 return module 605 else: 606 # Figure out where to slice the module's name up to the first dot 607 # in 'name'. 608 cut_off = len(name) - len(name.partition(".")[0]) 609 # Slice end needs to be positive to alleviate need to special-case 610 # when ``'.' not in name``. 611 module_name = demangle(module.__name__) 612 return self.modules[module_name[: len(module_name) - cut_off]] 613 else: 614 return self._handle_fromlist(module, fromlist) 615 616 def _get_package(self, package): 617 """Take a package name or module object and return the module. 618 619 If a name, the module is imported. If the passed or imported module 620 object is not a package, raise an exception. 621 """ 622 if hasattr(package, "__spec__"): 623 if package.__spec__.submodule_search_locations is None: 624 raise TypeError(f"{package.__spec__.name!r} is not a package") 625 else: 626 return package 627 else: 628 module = self.import_module(package) 629 if module.__spec__.submodule_search_locations is None: 630 raise TypeError(f"{package!r} is not a package") 631 else: 632 return module 633 634 def _zipfile_path(self, package, resource=None): 635 package = self._get_package(package) 636 assert package.__loader__ is self 637 name = demangle(package.__name__) 638 if resource is not None: 639 resource = _normalize_path(resource) 640 return f"{name.replace('.', '/')}/{resource}" 641 else: 642 return f"{name.replace('.', '/')}" 643 644 def _get_or_create_package( 645 self, atoms: List[str] 646 ) -> "Union[_PackageNode, _ExternNode]": 647 cur = self.root 648 for i, atom in enumerate(atoms): 649 node = cur.children.get(atom, None) 650 if node is None: 651 node = cur.children[atom] = _PackageNode(None) 652 if isinstance(node, _ExternNode): 653 return node 654 if isinstance(node, _ModuleNode): 655 name = ".".join(atoms[:i]) 656 raise ImportError( 657 f"inconsistent module structure. module {name} is not a package, but has submodules" 658 ) 659 assert isinstance(node, _PackageNode) 660 cur = node 661 return cur 662 663 def _add_file(self, filename: str): 664 """Assembles a Python module out of the given file. Will ignore files in the .data directory. 665 666 Args: 667 filename (str): the name of the file inside of the package archive to be added 668 """ 669 *prefix, last = filename.split("/") 670 if len(prefix) > 1 and prefix[0] == ".data": 671 return 672 package = self._get_or_create_package(prefix) 673 if isinstance(package, _ExternNode): 674 raise ImportError( 675 f"inconsistent module structure. package contains a module file {filename}" 676 f" that is a subpackage of a module marked external." 677 ) 678 if last == "__init__.py": 679 package.source_file = filename 680 elif last.endswith(".py"): 681 package_name = last[: -len(".py")] 682 package.children[package_name] = _ModuleNode(filename) 683 684 def _add_extern(self, extern_name: str): 685 *prefix, last = extern_name.split(".") 686 package = self._get_or_create_package(prefix) 687 if isinstance(package, _ExternNode): 688 return # the shorter extern covers this extern case 689 package.children[last] = _ExternNode() 690 691 692_NEEDS_LOADING = object() 693_ERR_MSG_PREFIX = "No module named " 694_ERR_MSG = _ERR_MSG_PREFIX + "{!r}" 695 696 697class _PathNode: 698 pass 699 700 701class _PackageNode(_PathNode): 702 def __init__(self, source_file: Optional[str]): 703 self.source_file = source_file 704 self.children: Dict[str, _PathNode] = {} 705 706 707class _ModuleNode(_PathNode): 708 __slots__ = ["source_file"] 709 710 def __init__(self, source_file: str): 711 self.source_file = source_file 712 713 714class _ExternNode(_PathNode): 715 pass 716 717 718# A private global registry of all modules that have been package-imported. 719_package_imported_modules: WeakValueDictionary = WeakValueDictionary() 720 721# `inspect` by default only looks in `sys.modules` to find source files for classes. 722# Patch it to check our private registry of package-imported modules as well. 723_orig_getfile = inspect.getfile 724 725 726def _patched_getfile(object): 727 if inspect.isclass(object): 728 if object.__module__ in _package_imported_modules: 729 return _package_imported_modules[object.__module__].__file__ 730 return _orig_getfile(object) 731 732 733inspect.getfile = _patched_getfile 734 735 736class _PackageResourceReader: 737 """Private class used to support PackageImporter.get_resource_reader(). 738 739 Confirms to the importlib.abc.ResourceReader interface. Allowed to access 740 the innards of PackageImporter. 741 """ 742 743 def __init__(self, importer, fullname): 744 self.importer = importer 745 self.fullname = fullname 746 747 def open_resource(self, resource): 748 from io import BytesIO 749 750 return BytesIO(self.importer.load_binary(self.fullname, resource)) 751 752 def resource_path(self, resource): 753 # The contract for resource_path is that it either returns a concrete 754 # file system path or raises FileNotFoundError. 755 if isinstance( 756 self.importer.zip_reader, DirectoryReader 757 ) and self.importer.zip_reader.has_record( 758 os.path.join(self.fullname, resource) 759 ): 760 return os.path.join( 761 self.importer.zip_reader.directory, self.fullname, resource 762 ) 763 raise FileNotFoundError 764 765 def is_resource(self, name): 766 path = self.importer._zipfile_path(self.fullname, name) 767 return self.importer.zip_reader.has_record(path) 768 769 def contents(self): 770 from pathlib import Path 771 772 filename = self.fullname.replace(".", "/") 773 774 fullname_path = Path(self.importer._zipfile_path(self.fullname)) 775 files = self.importer.zip_reader.get_all_records() 776 subdirs_seen = set() 777 for filename in files: 778 try: 779 relative = Path(filename).relative_to(fullname_path) 780 except ValueError: 781 continue 782 # If the path of the file (which is relative to the top of the zip 783 # namespace), relative to the package given when the resource 784 # reader was created, has a parent, then it's a name in a 785 # subdirectory and thus we skip it. 786 parent_name = relative.parent.name 787 if len(parent_name) == 0: 788 yield relative.name 789 elif parent_name not in subdirs_seen: 790 subdirs_seen.add(parent_name) 791 yield parent_name 792