xref: /aosp_15_r20/external/pytorch/torch/_deploy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport io
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
6*da0073e9SAndroid Build Coastguard Workerfrom torch.package._package_pickler import create_pickler
7*da0073e9SAndroid Build Coastguard Workerfrom torch.package._package_unpickler import PackageUnpickler
8*da0073e9SAndroid Build Coastguard Workerfrom torch.serialization import _maybe_decode_ascii
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
def _save_storages(importer, obj):
    """Pickle ``obj`` while pulling storages and ``__reduce_deploy__`` results out-of-band.

    Args:
        importer: The importer ``obj`` came from. Only a
            ``torch.package.PackageImporter`` is honored (searched before
            ``sys_importer``); any other value means "no package" and
            ``sys_importer`` is used alone.
        obj: The object to serialize.

    Returns:
        A 4-tuple ``(pickle_bytes, storages, dtypes, zip_reader)``:

        - ``storages`` and ``dtypes`` are parallel lists, indexed by the
          ``("storage", i)`` persistent ids embedded in ``pickle_bytes``.
        - ``zip_reader`` is the package importer's ``zip_reader``, or
          ``None`` when no package importer was given.
    """
    storages = []
    dtypes = []

    package = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer = (
        sys_importer if package is None else OrderedImporter(package, sys_importer)
    )

    def persistent_id(candidate):
        # Storages are not pickled inline: record them and emit an index.
        if torch.is_storage(candidate) or isinstance(
            candidate, torch.storage.TypedStorage
        ):
            # TODO: Once we decide to break serialization FC, we can
            # remove the TypedStorage case
            dtype = (
                candidate.dtype
                if isinstance(candidate, torch.storage.TypedStorage)
                else torch.uint8
            )
            storages.append(candidate)
            dtypes.append(dtype)
            return ("storage", len(storages) - 1)

        if not hasattr(candidate, "__reduce_deploy__"):
            # None tells the pickler to serialize the object normally.
            return None

        key = id(candidate)
        # NOTE(review): the module-level cache is keyed by id() and stores
        # only the id, not the object, so a later object that reuses a freed
        # id would receive this stale entry. The loader dedupes on the same
        # id, so fixing this requires a protocol change.
        cached = _serialized_reduces.get(key)
        if cached is None:
            cached = _serialized_reduces[key] = (
                "reduce_deploy",
                key,
                *candidate.__reduce_deploy__(importers),
            )
        return cached

    buffer = io.BytesIO()
    pickler = create_pickler(buffer, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    payload = buffer.getvalue()

    return payload, storages, dtypes, (package.zip_reader if package else None)
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Unpickle a payload produced by ``_save_storages`` and cache the result.

    Args:
        id: Key under which the unpickled object is stored in
            ``_deploy_objects``. (It shadows the builtin ``id``; the name is
            kept for interface compatibility.)
        zip_reader: The ``zip_reader`` returned by ``_save_storages``, or
            ``None`` if the object was pickled without a package importer.
        obj_bytes: The pickle payload returned by ``_save_storages``.
        serialized_storages: Storages referenced by ``("storage", index)``
            persistent ids, indexed as ``_save_storages`` recorded them.
        serialized_dtypes: The dtype for each entry of ``serialized_storages``.

    Returns:
        The unpickled object, which is also stored in ``_deploy_objects[id]``.

    Raises:
        pickle.UnpicklingError: If the payload contains a persistent id that
            ``_save_storages`` does not produce.
    """
    import pickle

    def persistent_load(saved_id):
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(saved_id, tuple) or not saved_id:
            raise pickle.UnpicklingError(f"unsupported persistent id: {saved_id!r}")
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            # Rebuild each reduce_id only once and reuse it afterwards.
            if reduce_id not in _loaded_reduces:
                # NOTE(review): when zip_reader is None, _get_package() was
                # never called, so this lookup raises KeyError. Confirm
                # whether reduce_deploy ids can occur without a package.
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        # _save_storages emits only the two typenames above. Returning None
        # here would silently replace the referenced object with None.
        raise pickle.UnpicklingError(
            f"unsupported persistent id type: {typename!r}"
        )

    importer: Importer
    if zip_reader is not None:
        # Consult the package's importer before the regular import system.
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[method-assign]
    result = _deploy_objects[id] = unpickler.load()
    return result
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker
def _get_package(zip_reader):
    """Return the PackageImporter for ``zip_reader``, creating it on first use.

    The importer is memoized in the module-level ``_raw_packages`` dict, so
    repeated calls with the same ``zip_reader`` return the same object.
    """
    try:
        return _raw_packages[zip_reader]
    except KeyError:
        package = _raw_packages[zip_reader] = PackageImporter(zip_reader)
        return package
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker
# Module-level caches used by the functions in this file. Nothing in this
# file removes entries once they are added.

# zip_reader -> PackageImporter, populated by _get_package().
_raw_packages: dict = dict()
# `id` argument of _load_storages() -> the object it unpickled.
_deploy_objects: dict = dict()
# id(obj) -> ("reduce_deploy", id(obj), *obj.__reduce_deploy__(...)) persistent id.
_serialized_reduces: dict = dict()
# reduce_id from a "reduce_deploy" persistent id -> the object rebuilt from it.
_loaded_reduces: dict = dict()
105