# mypy: allow-untyped-defs
import io

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii


def _save_storages(importer, obj):
    """Pickle ``obj`` with storages and ``__reduce_deploy__`` objects held out-of-band.

    Every storage reached during pickling is written into the byte stream as the
    persistent id ``("storage", index)``. The storage itself and its dtype are
    appended at ``index`` to the two returned lists. Objects defining
    ``__reduce_deploy__`` are written as
    ``("reduce_deploy", id(obj), *obj.__reduce_deploy__(importers))``.

    Args:
        importer: When this is a ``torch.package.PackageImporter``, modules are
            resolved through it first and then through ``sys_importer``. Any
            other value (including ``None``) means ``sys_importer`` alone.
        obj: The object to serialize.

    Returns:
        A 4-tuple ``(pickled_bytes, storages, dtypes, zip_reader)``:
        ``storages`` and ``dtypes`` are parallel lists indexed by the storage
        persistent ids. ``zip_reader`` is ``importer.zip_reader`` when a
        ``PackageImporter`` was given, else ``None``.
    """
    held_storages = []
    held_dtypes = []

    package = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer = (
        sys_importer if package is None else OrderedImporter(package, sys_importer)
    )

    def persistent_id(candidate):
        # The pickler consults this for each object it is about to save; a
        # non-None return is written as a persistent id instead of the object.
        typed = isinstance(candidate, torch.storage.TypedStorage)
        if typed or torch.is_storage(candidate):
            # TODO: Once we decide to break serialization FC, we can
            # remove the TypedStorage case. Untyped storages are recorded
            # with dtype uint8.
            held_storages.append(candidate)
            held_dtypes.append(candidate.dtype if typed else torch.uint8)
            return ("storage", len(held_storages) - 1)

        if not hasattr(candidate, "__reduce_deploy__"):
            return None

        # Memoized in the module-level _serialized_reduces, so __reduce_deploy__
        # runs at most once per id(candidate).
        # NOTE(review): nothing in this file clears that dict, and id() values
        # can be reused after an object is freed — confirm entries cannot go
        # stale.
        key = id(candidate)
        entry = _serialized_reduces.get(key)
        if entry is None:
            entry = ("reduce_deploy", key, *candidate.__reduce_deploy__(importers))
            _serialized_reduces[key] = entry
        return entry

    buffer = io.BytesIO()
    pickler = create_pickler(buffer, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    zip_reader = package.zip_reader if package else None
    return buffer.getvalue(), held_storages, held_dtypes, zip_reader


def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Unpickle bytes produced by ``_save_storages`` and record the result.

    Args:
        id: Key under which the loaded object is stored in ``_deploy_objects``.
        zip_reader: The ``zip_reader`` returned by ``_save_storages``. When not
            ``None``, modules are resolved through the ``PackageImporter``
            cached for it by ``_get_package`` and then through ``sys_importer``.
        obj_bytes: The pickled bytes.
        serialized_storages: Storages referenced by ``("storage", index)``
            persistent ids.
        serialized_dtypes: Dtypes parallel to ``serialized_storages``.

    Returns:
        The unpickled object. The same object is also stored as
        ``_deploy_objects[id]``.
    """

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        tag = _maybe_decode_ascii(saved_id[0])
        fields = saved_id[1:]

        if tag == "storage":
            slot = fields[0]
            raw = serialized_storages[slot]
            element_type = serialized_dtypes[slot]
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            return torch.storage.TypedStorage(
                wrap_storage=raw.untyped(), dtype=element_type
            )

        if tag == "reduce_deploy":
            reduce_id, func, args = fields
            # Each reduce_id is rebuilt once; later references reuse the result.
            if reduce_id not in _loaded_reduces:
                # NOTE(review): indexes _raw_packages directly, which looks like
                # it raises KeyError when zip_reader is None — presumably
                # reduce_deploy ids only occur with a package; confirm.
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        # NOTE(review): unrecognized persistent-id types silently load as None.
        return None

    importer: Importer = (
        sys_importer
        if zip_reader is None
        else OrderedImporter(_get_package(zip_reader), sys_importer)
    )

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[method-assign]
    loaded = unpickler.load()
    _deploy_objects[id] = loaded
    return loaded


def _get_package(zip_reader):
    """Return the ``PackageImporter`` for ``zip_reader``, creating and caching it on first use."""
    try:
        return _raw_packages[zip_reader]
    except KeyError:
        pass
    package = PackageImporter(zip_reader)
    _raw_packages[zip_reader] = package
    return package


# zip_reader -> PackageImporter, populated by _get_package.
_raw_packages: dict = {}
# id -> object most recently loaded under that id by _load_storages.
_deploy_objects: dict = {}
# id(obj) -> ("reduce_deploy", id(obj), ...) persistent ids built by _save_storages.
_serialized_reduces: dict = {}
# reduce_id -> object rebuilt from a "reduce_deploy" persistent id by _load_storages.
_loaded_reduces: dict = {}