xref: /aosp_15_r20/external/pytorch/torch/package/_package_pickler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from pickle import (  # type: ignore[attr-defined]
3    _compat_pickle,
4    _extension_registry,
5    _getattribute,
6    _Pickler,
7    EXT1,
8    EXT2,
9    EXT4,
10    GLOBAL,
11    Pickler,
12    PicklingError,
13    STACK_GLOBAL,
14)
15from struct import pack
16from types import FunctionType
17
18from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
19
20
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.

    Args:
        importer: The `Importer` used to resolve the module name and qualified
            name of every object that is pickled by reference.
        *args: Forwarded unchanged to the pure-Python ``pickle._Pickler``.
        **kwargs: Forwarded unchanged to the pure-Python ``pickle._Pickler``.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Issues have been encountered in the past where a library (e.g. dill)
        # mutates _Pickler.dispatch, PackagePickler takes a copy while that
        # library is imported, and the library later removes its dispatch
        # entries again. That leaves PackagePickler with a stale dispatch
        # table that may cause unwanted behavior, so copy it per instance.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Pickle ``obj`` by reference, i.e. as a ``(module, qualified name)`` pair.

        This is a copy of ``pickle._Pickler.save_global``; the pickler code is
        factored in a way that forces the copy. The only functional change is
        that the module and name are resolved through ``self.importer`` instead
        of ``__import__`` (marked CHANGED below).

        Args:
            obj: The object to pickle by reference.
            name: Optional name hint passed on to ``self.importer.get_name``.

        Raises:
            PicklingError: If the importer cannot find ``obj`` or finds a
                different object under its name, or if the identifier is not
                ASCII-encodable while using a pickle protocol below 3.
        """
        # Function-scope import: only needed for the interpreter version check.
        import sys

        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None

        module = self.importer.import_module(module_name)
        if sys.version_info >= (3, 14):
            # Python 3.14 changed the private helper pickle._getattribute: it
            # now takes an already-split attribute path and returns only the
            # resolved object (no ``(obj, parent)`` tuple). Resolve the parent
            # of the target by walking every path component but the last one.
            path = name.split(".")
            if "<locals>" in path:
                raise PicklingError(f"Can't pickle local object {obj!r}")
            parent = _getattribute(module, path[:-1]) if len(path) > 1 else module
        else:
            _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Objects registered through copyreg's extension registry are
            # written as a compact EXT opcode instead of a module/name pair.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return
        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Nested attribute on an older protocol: pickle it as
            # getattr(parent, lastname).
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Map Python 3 names back to their Python 2 equivalents.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the module *name*; formatting the module object here
                # would embed its repr ("<module ...>") in the identifier.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]
110
111
def create_pickler(data_buf, importer, protocol=4):
    """Return a pickler writing to ``data_buf`` that suits ``importer``.

    Args:
        data_buf: Destination object handed to the pickler as its output file.
        importer: Importer used to look up objects and modules while pickling.
        protocol: Pickle protocol number passed to the pickler.

    Returns:
        A ``PackagePickler`` bound to ``importer``, unless ``importer`` is
        ``sys_importer``, in which case the plain ``pickle.Pickler`` is used.
    """
    if importer is not sys_importer:
        return PackagePickler(importer, data_buf, protocol=protocol)
    # With the normal import system nothing package-specific has to be
    # resolved, so the faster C implementation of pickle can be used.
    return Pickler(data_buf, protocol=protocol)
119