# mypy: allow-untyped-defs
from pickle import (  # type: ignore[attr-defined]
    _compat_pickle,
    _extension_registry,
    _getattribute,
    _Pickler,
    EXT1,
    EXT2,
    EXT4,
    GLOBAL,
    Pickler,
    PicklingError,
    STACK_GLOBAL,
)
from struct import pack
from types import FunctionType

from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer


class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: Used by ``save_global`` to look up an object's
                module/qualified name and to import that module.
            *args: Forwarded unchanged to ``_Pickler.__init__``.
            **kwargs: Forwarded unchanged to ``_Pickler.__init__``.
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Previous issues have been encountered where a library (e.g. dill)
        # mutates _Pickler.dispatch, PackagePickler makes a copy when this lib
        # is imported, then the offending library removes its dispatch entries,
        # leaving PackagePickler with a stale dispatch table that may cause
        # unwanted behavior.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # Route plain functions through our importer-aware save_global.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Pickle ``obj`` by reference, resolving it through ``self.importer``.

        Args:
            obj: The object to pickle as a global reference.
            name: Optional name for ``obj``; passed to
                ``self.importer.get_name`` together with ``obj``.

        Raises:
            PicklingError: If the importer raises ``ObjNotFoundError`` or
                ``ObjMismatchError`` for ``obj``, or if the identifier cannot
                be ASCII-encoded when pickling with a protocol below 3.
        """
        # unfortunately the pickler code is factored in a way that
        # forces us to copy/paste this function. The only change is marked
        # CHANGED below.
        write = self.write  # type: ignore[attr-defined]
        memo = self.memo  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None

        module = self.importer.import_module(module_name)
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Prefer a registered extension code over a by-name reference.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                # Pick the narrowest EXT opcode that can hold the code.
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return
        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Nested attribute: emit it as getattr(parent, lastname).
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the dotted module *name*; formatting the module
                # object itself would embed its repr in the message.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]


def create_pickler(data_buf, importer, protocol=4):
    """Return a pickler writing to ``data_buf`` that suits ``importer``.

    Args:
        data_buf: Passed to the pickler as its output file argument.
        importer: The importer used to resolve globals while pickling.
        protocol: Pickle protocol passed to the pickler (default 4).

    Returns:
        The stdlib ``Pickler`` when ``importer`` is ``sys_importer``,
        otherwise a ``PackagePickler`` bound to ``importer``.
    """
    if importer is sys_importer:
        # if we are using the normal import library system, then
        # we can use the C implementation of pickle which is faster
        return Pickler(data_buf, protocol=protocol)
    else:
        return PackagePickler(importer, data_buf, protocol=protocol)