# /aosp_15_r20/external/pytorch/torch/_deploy.py
# mypy: allow-untyped-defs
import io

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii


def _save_storages(importer, obj):
    """Pickle ``obj`` with a torch.package-aware pickler, externalizing its
    tensor storages.

    Returns a tuple of (pickle bytes, serialized storages, their dtypes,
    and the importer's zip_reader, or None if no PackageImporter was given).
    """
    serialized_storages = []
    serialized_dtypes = []

    # Fall back to sys_importer unless a real PackageImporter was supplied.
    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        # Storages are pulled out of the pickle stream and returned by index.
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                dtype = obj.dtype
            else:
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        # Objects that know how to serialize themselves for deploy are
        # reduced once and cached by their id().
        if hasattr(obj, "__reduce_deploy__"):
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )
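

# Usage sketch (illustrative only; `_example_save` is a hypothetical helper,
# not part of this module): with no PackageImporter available, passing None
# falls back to sys_importer and the returned zip_reader is None.
def _example_save(obj):
    data, storages, dtypes, zip_reader = _save_storages(None, obj)
    assert zip_reader is None  # no torch.package archive was involved
    return data, storages, dtypes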


def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Inverse of ``_save_storages``: unpickle ``obj_bytes``, resolving the
    externalized storages and ``__reduce_deploy__`` results, and cache the
    result in ``_deploy_objects`` under ``id``."""

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[method-assign]
    result = _deploy_objects[id] = unpickler.load()
    return result
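

# Round-trip sketch (illustrative only; the helper and the cache key 0 are
# hypothetical): pickle an object through the sys_importer fallback path,
# then immediately rebuild it from the storages produced by the save.
def _example_round_trip(obj, obj_id=0):
    data, storages, dtypes, zip_reader = _save_storages(None, obj)
    return _load_storages(obj_id, zip_reader, data, storages, dtypes)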


def _get_package(zip_reader):
    # Cache one PackageImporter per zip_reader so that repeated loads from the
    # same archive share module (and therefore class) identity.
    if zip_reader not in _raw_packages:
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
    return _raw_packages[zip_reader]
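

# Caching sketch (illustrative only; hypothetical helper, assumes
# `zip_reader` is a valid torch.package zip reader): lookups with the same
# zip_reader yield the identical PackageImporter instance.
def _example_importer_is_cached(zip_reader):
    return _get_package(zip_reader) is _get_package(zip_reader)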


# Process-wide caches shared by the helpers above.
_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}
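

if __name__ == "__main__":
    # Smoke-test sketch (not part of the original module; assumes torch is
    # importable): round-trip a small tensor with no torch.package archive.
    t = torch.arange(4, dtype=torch.float32)
    data, storages, dtypes, zip_reader = _save_storages(None, {"t": t})
    restored = _load_storages(0, zip_reader, data, storages, dtypes)
    assert torch.equal(restored["t"], t)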