# mypy: allow-untyped-defs
import io

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii


def _save_storages(importer, obj):
    # Pickle ``obj``, recording its storages (and their dtypes) out-of-band so
    # they can be handed over directly instead of being copied into the stream.
    serialized_storages = []
    serialized_dtypes = []

    # Only a PackageImporter carries a zip_reader we can return to the caller;
    # any other importer falls back to the plain system importer.
    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                dtype = obj.dtype
            else:
                dtype = torch.uint8

            # Record the storage out-of-band; the pickle stream only carries
            # an index into ``serialized_storages``.
            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            # Objects that know how to rebuild themselves in another
            # interpreter are reduced once and memoized by their id().
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )


def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    # Rebuild an object pickled by ``_save_storages``. ``id`` is the cache key
    # under which the result is stored in ``_deploy_objects``.
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[method-assign]
    result = _deploy_objects[id] = unpickler.load()
    return result


def _get_package(zip_reader):
    # Cache one PackageImporter per zip_reader so repeated loads share state.
    if zip_reader not in _raw_packages:
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
    return _raw_packages[zip_reader]


_raw_packages: dict = {}  # zip_reader -> PackageImporter
_deploy_objects: dict = {}  # load id -> unpickled object
_serialized_reduces: dict = {}  # id(obj) -> ("reduce_deploy", id, func, args)
_loaded_reduces: dict = {}  # reduce id -> reconstructed object
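

# ---------------------------------------------------------------------------
# Hedged usage sketch (an addition, not part of the original API surface):
# round-trip a tensor through the save/load pair above using only the system
# importer, i.e. with no torch.package archive involved. The "demo" string
# passed as the load ``id`` is an arbitrary cache key for ``_deploy_objects``.
if __name__ == "__main__":
    tensor = torch.arange(4, dtype=torch.float32)

    # ``sys_importer`` is not a PackageImporter, so ``_save_storages`` falls
    # back to plain system imports and returns ``zip_reader=None``.
    data, storages, dtypes, zip_reader = _save_storages(sys_importer, tensor)

    restored = _load_storages("demo", zip_reader, data, storages, dtypes)
    assert torch.equal(tensor, restored)
    print("round-trip ok:", restored)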