xref: /aosp_15_r20/external/pytorch/torch/utils/show_pickle.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3import sys
4import pickle
5import struct
6import pprint
7import zipfile
8import fnmatch
9from typing import Any, IO, BinaryIO, Union
10
11__all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"]
12
class FakeObject:
    """Placeholder for an object that the dump unpickler refused to build.

    Records the would-be class's ``module`` and ``name``, the positional
    ``args`` it was created with, and any ``state`` delivered through
    ``__setstate__`` (which pickle's BUILD opcode calls when present).
    """

    def __init__(self, module, name, args):
        """Create a placeholder.

        Args:
            module: Module part of the fake object's qualified name.
            name: Attribute/class part of the fake object's qualified name.
            args: Tuple of constructor arguments to display.
        """
        self.module = module
        self.name = name
        self.args = args
        # NOTE: We don't distinguish between state never set and state set to None.
        self.state = None

    def __repr__(self):
        # e.g. "mod.Cls(1, 2)" or "mod.Cls(1, 2)(state={...})"
        state_str = "" if self.state is None else f"(state={self.state!r})"
        return f"{self.module}.{self.name}{self.args!r}{state_str}"

    def __setstate__(self, state):
        # Store whatever state the pickle supplies instead of applying it.
        self.state = state

    @staticmethod
    def pp_format(printer, obj, stream, indent, allowance, context, level):
        """Multi-line ``pprint`` formatter for ``FakeObject`` instances.

        The signature matches the entries of ``pprint.PrettyPrinter._dispatch``
        (a private table keyed by a type's ``__repr__``), so this can be
        registered there. The output has the same shape as ``__repr__``:
        ``module.name(args...)``, optionally followed by ``(state=`` and the
        state on its own, further-indented line, then ``)``.

        Args:
            printer: The active ``pprint.PrettyPrinter``.
            obj: The ``FakeObject`` being printed.
            stream: Text stream to write to.
            indent: Current indentation column.
            allowance: Columns to reserve for trailing text.
            context: Recursion-tracking dict passed through to ``_format``.
            level: Current nesting depth.
        """
        if not obj.args and obj.state is None:
            stream.write(repr(obj))
            return
        stream.write(f"{obj.module}.{obj.name}")
        if obj.args:
            printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
        else:
            # Keep the "()" that __repr__ prints for an empty args tuple.
            stream.write("()")
        if obj.state is None:
            return
        # Objects built via REDUCE/NEWOBJ and then BUILD carry both args and
        # state; print the state indented on its own line.
        stream.write("(state=\n")
        indent += printer._indent_per_level
        stream.write(" " * indent)
        printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
        stream.write(")")
45
46
class FakeClass:
    """Placeholder returned by ``DumpUnpickler.find_class`` for every global.

    Instead of constructing a real instance, calling it (pickle's REDUCE,
    INST and OBJ opcodes) or invoking its ``__new__`` (NEWOBJ / NEWOBJ_EX)
    returns a ``FakeObject`` that records the arguments.
    """

    def __init__(self, module, name):
        self.module = module
        self.name = name
        # The unpickler's NEWOBJ handlers call ``cls.__new__(cls, *args)`` on
        # whatever find_class returned. Here that "class" is this instance,
        # so shadow __new__ with an instance attribute pointing at fake_new.
        self.__new__ = self.fake_new  # type: ignore[assignment]

    def __repr__(self):
        return f"{self.module}.{self.name}"

    def __call__(self, *args):
        return FakeObject(self.module, self.name, args)

    def fake_new(self, *args, **kwargs):
        """Handle ``cls.__new__(cls, *args, **kwargs)`` from the unpickler.

        ``args[0]`` is the "class" argument (this FakeClass) and is dropped.
        Keyword arguments arrive only through NEWOBJ_EX (protocol 4+). When
        present, they are appended to the recorded args as a trailing dict so
        they are displayed instead of raising TypeError.
        """
        new_args = args[1:]
        if kwargs:
            new_args += (kwargs,)
        return FakeObject(self.module, self.name, new_args)
61
62
class DumpUnpickler(pickle._Unpickler):  # type: ignore[name-defined]
    """Pure-Python unpickler that yields placeholders instead of real objects.

    Every global lookup becomes a ``FakeClass`` and every persistent id a
    ``FakeObject``, so a pickle can be displayed without importing the
    modules it references.
    """

    def __init__(
            self,
            file,
            *,
            catch_invalid_utf8=False,
            **kwargs):
        """Args:
            file: Binary stream to unpickle from.
            catch_invalid_utf8: If True, strings that are not valid UTF-8 are
                replaced by a ``FakeObject`` describing the decode error
                instead of raising ``UnicodeDecodeError``.
            **kwargs: Forwarded to ``pickle._Unpickler``.
        """
        super().__init__(file, **kwargs)
        self.catch_invalid_utf8 = catch_invalid_utf8

    def find_class(self, module, name):
        return FakeClass(module, name)

    def persistent_load(self, pid):
        return FakeObject("pers", "obj", (pid,))

    dispatch = dict(pickle._Unpickler.dispatch)  # type: ignore[attr-defined]

    # Custom objects in TorchScript are able to return invalid UTF-8 strings
    # from their pickle (__getstate__) functions.  Install custom loaders
    # for every binary string opcode that catch the decode exception and
    # replace the string with a sentinel object.
    def _append_utf8(self, str_bytes):
        """Decode ``str_bytes`` the way pickle does and push it on the stack."""
        obj: Any
        try:
            obj = str(str_bytes, "utf-8", "surrogatepass")
        except UnicodeDecodeError as exn:
            if not self.catch_invalid_utf8:
                raise
            obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),))
        self.append(obj)  # type: ignore[attr-defined]

    def load_binunicode(self):
        # BINUNICODE: 4-byte little-endian length prefix.
        strlen, = struct.unpack("<I", self.read(4))  # type: ignore[attr-defined]
        if strlen > sys.maxsize:
            raise Exception("String too long.")  # noqa: TRY002
        self._append_utf8(self.read(strlen))  # type: ignore[attr-defined]
    dispatch[pickle.BINUNICODE[0]] = load_binunicode  # type: ignore[assignment]

    def load_short_binunicode(self):
        # SHORT_BINUNICODE (protocol 4+): 1-byte length prefix.
        strlen = self.read(1)[0]  # type: ignore[attr-defined]
        self._append_utf8(self.read(strlen))  # type: ignore[attr-defined]
    dispatch[pickle.SHORT_BINUNICODE[0]] = load_short_binunicode  # type: ignore[assignment]

    def load_binunicode8(self):
        # BINUNICODE8 (protocol 4+): 8-byte little-endian length prefix.
        strlen, = struct.unpack("<Q", self.read(8))  # type: ignore[attr-defined]
        if strlen > sys.maxsize:
            raise Exception("String too long.")  # noqa: TRY002
        self._append_utf8(self.read(strlen))  # type: ignore[attr-defined]
    dispatch[pickle.BINUNICODE8[0]] = load_binunicode8  # type: ignore[assignment]

    @classmethod
    def dump(cls, in_stream, out_stream, **kwargs):
        """Unpickle one object from ``in_stream`` and pretty-print it.

        Args:
            in_stream: Binary stream containing the pickle.
            out_stream: Text stream to print to (``None`` -> stdout, per
                ``pprint.pprint``).
            **kwargs: Forwarded to the constructor, e.g.
                ``catch_invalid_utf8=True``.

        Returns:
            The unpickled value.
        """
        value = cls(in_stream, **kwargs).load()
        pprint.pprint(value, stream=out_stream)
        return value
105
106
def main(argv, output_stream=None):
    """Pretty-print the pickle named by ``argv[1]``.

    ``argv[1]`` (PICKLE_FILE) may be:

    * a path to a pickle file;
    * ``ZIP@MEMBER`` -- a member of a zip archive;
    * ``ZIP@PATTERN`` -- if the part after ``@`` contains ``*``, it is matched
      against member names with ``fnmatch`` and only the first match is shown.

    Args:
        argv: Argument vector including the program name at index 0.
        output_stream: Text stream for the output; ``None`` prints to stdout.

    Returns:
        2 after writing usage to stderr when ``argv`` does not have exactly
        two entries and ``output_stream`` is None; otherwise None.

    Raises:
        Exception: if ``argv`` has the wrong length while ``output_stream`` is
            given, or if no zip member matches the pattern.
    """
    if len(argv) != 2:
        # Don't spam stderr if not using stdout.
        if output_stream is not None:
            raise Exception("Pass argv of length 2.")  # noqa: TRY002
        sys.stderr.write("usage: show_pickle PICKLE_FILE\n")
        sys.stderr.write("  PICKLE_FILE can be any of:\n")
        sys.stderr.write("    path to a pickle file\n")
        sys.stderr.write("    file.zip@member.pkl\n")
        sys.stderr.write("    file.zip@*/pattern.*\n")
        sys.stderr.write("      (shell glob pattern for members)\n")
        sys.stderr.write("      (only first match will be shown)\n")
        return 2

    fname = argv[1]
    handle: Union[IO[bytes], BinaryIO]
    if "@" not in fname:
        with open(fname, "rb") as handle:
            DumpUnpickler.dump(handle, output_stream)
    else:
        # Split on the first "@": archive path on the left, member on the right.
        zfname, mname = fname.split("@", 1)
        with zipfile.ZipFile(zfname) as zf:
            if "*" not in mname:
                with zf.open(mname) as handle:
                    DumpUnpickler.dump(handle, output_stream)
            else:
                for info in zf.infolist():
                    if fnmatch.fnmatch(info.filename, mname):
                        with zf.open(info) as handle:
                            DumpUnpickler.dump(handle, output_stream)
                        break
                else:
                    raise Exception(f"Could not find member matching {mname} in {zfname}")  # noqa: TRY002
142
143
if __name__ == "__main__":
    # pprint has no public hook for custom formatters, so register
    # FakeObject.pp_format in PrettyPrinter's private ``_dispatch`` table
    # (keyed by a type's __repr__ function). The table is a CPython
    # implementation detail; if it is absent, fall back to FakeObject's
    # plain single-line repr rather than crashing.
    pprint_dispatch = getattr(pprint.PrettyPrinter, "_dispatch", None)
    if pprint_dispatch is not None:
        pprint_dispatch[FakeObject.__repr__] = FakeObject.pp_format

    sys.exit(main(sys.argv))
152