#!/usr/bin/env python3
# mypy: allow-untyped-defs
"""Pretty-print the contents of a pickle without importing the pickled classes.

Every class referenced by the pickle is replaced by a ``FakeClass`` and every
instance by a ``FakeObject`` that only records the module, name, constructor
arguments and state, so the structure of the pickle can be inspected as text.

Usage::

    show_pickle PICKLE_FILE

where ``PICKLE_FILE`` is a path to a pickle file, ``file.zip@member`` for a
member of a zip archive, or ``file.zip@<glob>`` to show the first archive
member whose name matches a shell-style pattern containing ``*``.
"""
import sys
import pickle
import struct
import pprint
import zipfile
import fnmatch
from typing import Any, IO, BinaryIO, Union

__all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"]


class FakeObject:
    """Placeholder for an unpickled instance.

    Records the ``module`` and ``name`` of the class the pickle referred to,
    the positional ``args`` it was constructed with, and the ``state`` passed
    to ``__setstate__`` (``None`` until then).
    """

    def __init__(self, module, name, args):
        self.module = module
        self.name = name
        self.args = args
        # NOTE: We don't distinguish between state never set and state set to None.
        self.state = None

    def __repr__(self):
        state_str = "" if self.state is None else f"(state={self.state!r})"
        return f"{self.module}.{self.name}{self.args!r}{state_str}"

    def __setstate__(self, state):
        # Called by the unpickler's BUILD handling; just remember the state.
        self.state = state

    @staticmethod
    def pp_format(printer, obj, stream, indent, allowance, context, level):
        """Write a multi-line rendering of ``obj`` to ``stream``.

        The signature matches the callables stored in
        ``pprint.PrettyPrinter._dispatch``; the ``__main__`` block of this
        file registers it there under ``FakeObject.__repr__``.

        Output shapes:
          * no args, no state:  ``module.name()``
          * args only:          ``module.name`` followed by the formatted args
          * state (with or without args): the above (``()`` when there are no
            args) followed by ``(state=``, a newline, the formatted state
            indented one level deeper, and a closing ``)``.
        """
        if not obj.args and obj.state is None:
            stream.write(repr(obj))
            return

        stream.write(f"{obj.module}.{obj.name}")
        if obj.args:
            # Let the printer wrap the argument tuple if it is too wide.
            printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
        else:
            stream.write("()")

        if obj.state is None:
            return

        # State goes on its own line, one indentation level deeper.
        stream.write("(state=\n")
        indent += printer._indent_per_level
        stream.write(" " * indent)
        printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
        stream.write(")")


class FakeClass:
    """Placeholder for a class referenced by a pickle.

    Calling an instance, or calling its ``__new__`` attribute, yields a
    ``FakeObject`` carrying this class's module and name.
    """

    def __init__(self, module, name):
        self.module = module
        self.name = name
        # Instance-level __new__ so that ``cls.__new__(cls, *args)`` performed
        # on this placeholder produces a FakeObject as well.
        self.__new__ = self.fake_new  # type: ignore[assignment]

    def __repr__(self):
        return f"{self.module}.{self.name}"

    def __call__(self, *args):
        return FakeObject(self.module, self.name, args)

    def fake_new(self, *args):
        # args[0] is the class argument passed to __new__; drop it.
        return FakeObject(self.module, self.name, args[1:])


class DumpUnpickler(pickle._Unpickler):  # type: ignore[name-defined]
    """Unpickler that builds ``FakeClass``/``FakeObject`` placeholders.

    ``find_class`` never imports the module named in the pickle; it returns a
    ``FakeClass`` instead. Persistent ids are wrapped in a ``pers.obj``
    ``FakeObject``.

    Args:
        file: Binary stream to read the pickle from.
        catch_invalid_utf8: When true, a BINUNICODE string that is not valid
            UTF-8 is replaced by a ``builtin.UnicodeDecodeError`` FakeObject
            instead of raising.
        **kwargs: Forwarded to ``pickle._Unpickler``.
    """

    def __init__(
            self,
            file,
            *,
            catch_invalid_utf8=False,
            **kwargs):
        super().__init__(file, **kwargs)
        self.catch_invalid_utf8 = catch_invalid_utf8

    def find_class(self, module, name):
        return FakeClass(module, name)

    def persistent_load(self, pid):
        return FakeObject("pers", "obj", (pid,))

    # Private copy of the opcode table so the override below does not affect
    # pickle._Unpickler itself.
    dispatch = dict(pickle._Unpickler.dispatch)  # type: ignore[attr-defined]

    # Custom objects in TorchScript are able to return invalid UTF-8 strings
    # from their pickle (__getstate__) functions.  Install a custom loader
    # for strings that catches the decode exception and replaces it with
    # a sentinel object.
    def load_binunicode(self):
        strlen, = struct.unpack("<I", self.read(4))  # type: ignore[attr-defined]
        if strlen > sys.maxsize:
            raise Exception("String too long.")  # noqa: TRY002
        str_bytes = self.read(strlen)  # type: ignore[attr-defined]
        obj: Any
        try:
            obj = str(str_bytes, "utf-8", "surrogatepass")
        except UnicodeDecodeError as exn:
            if not self.catch_invalid_utf8:
                raise
            obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),))
        self.append(obj)  # type: ignore[attr-defined]
    dispatch[pickle.BINUNICODE[0]] = load_binunicode  # type: ignore[assignment]

    @classmethod
    def dump(cls, in_stream, out_stream):
        """Unpickle ``in_stream``, pretty-print the result to ``out_stream``, and return it."""
        value = cls(in_stream).load()
        pprint.pprint(value, stream=out_stream)
        return value


def main(argv, output_stream=None):
    """Dump the pickle named by ``argv[1]`` to ``output_stream``.

    Args:
        argv: ``[program_name, PICKLE_FILE]``. ``PICKLE_FILE`` is either a
            path, or ``zipfile@member`` where ``member`` may contain ``*`` to
            select the first archive member matching the glob pattern.
        output_stream: Text stream for the output; ``None`` lets
            ``pprint.pprint`` use its default.

    Returns:
        ``2`` after printing usage to stderr when ``argv`` has the wrong
        length and no ``output_stream`` was given; otherwise ``None``.

    Raises:
        Exception: If ``argv`` has the wrong length and ``output_stream`` was
            given, or if no zip member matches the glob pattern.
    """
    if len(argv) != 2:
        # Don't spam stderr if not using stdout.
        if output_stream is not None:
            raise Exception("Pass argv of length 2.")  # noqa: TRY002
        sys.stderr.write("usage: show_pickle PICKLE_FILE\n")
        sys.stderr.write(" PICKLE_FILE can be any of:\n")
        sys.stderr.write(" path to a pickle file\n")
        sys.stderr.write(" file.zip@member.pkl\n")
        sys.stderr.write(" file.zip@*/pattern.*\n")
        sys.stderr.write(" (shell glob pattern for members)\n")
        sys.stderr.write(" (only first match will be shown)\n")
        return 2

    fname = argv[1]
    handle: Union[IO[bytes], BinaryIO]
    if "@" not in fname:
        with open(fname, "rb") as handle:
            DumpUnpickler.dump(handle, output_stream)
    else:
        # Everything before the first "@" is the archive, the rest the member.
        zfname, mname = fname.split("@", 1)
        with zipfile.ZipFile(zfname) as zf:
            if "*" not in mname:
                with zf.open(mname) as handle:
                    DumpUnpickler.dump(handle, output_stream)
            else:
                # Show only the first member whose name matches the pattern.
                for info in zf.infolist():
                    if fnmatch.fnmatch(info.filename, mname):
                        with zf.open(info) as handle:
                            DumpUnpickler.dump(handle, output_stream)
                        break
                else:
                    raise Exception(f"Could not find member matching {mname} in {zfname}")  # noqa: TRY002


if __name__ == "__main__":
    # This hack works on every version of Python I've tested.
    # I've tested on the following versions:
    #   3.7.4
    pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format  # type: ignore[attr-defined]

    sys.exit(main(sys.argv))