# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors

import base64
import io
import json
import logging
import operator
import os
import zipfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import executorch.exir as exir
import executorch.exir.memory as memory
import executorch.exir.serde.export_serialize as export_serialize
import executorch.exir.serde.schema as schema
import torch
import torch.export.exported_program as ep
from executorch.exir import delegate
from executorch.exir.backend.compile_spec_schema import (
    CompileSpec as delegate_CompileSpec,
)
from executorch.exir.dialects._ops import _DialectNamespace, ops as exir_ops
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.lowered_backend_module import (
    LoweredBackendModule as ExirLoweredBackendModule,
)
from executorch.exir.serde.export_serialize import GraphModuleOpUpgrader, SerializeError
from executorch.exir.serde.schema import (
    CompileSpec,
    LoweredBackendModule as SerdeLoweredBackendModule,
    SCHEMA_VERSION,
    SchemaVersion,
)
from torch._export.verifier import load_verifier
from torch.fx.experimental import symbolic_shapes

log: logging.Logger = logging.getLogger(__name__)


class GraphModuleSerializer(export_serialize.GraphModuleSerializer):
    """Graph-module serializer with special handling for ExecuTorch nodes.

    On top of the parent serializer this class handles edge/backend dialect
    operators, ``memory.alloc`` nodes and ``executorch_call_delegate`` nodes.
    """

    def __init__(
        self,
        graph_signature: ep.ExportGraphSignature,
        module_call_graph: List[ep.ModuleCallEntry],
    ) -> None:
        """Forward both arguments to the parent and start with an empty state dict."""
        super().__init__(graph_signature, module_call_graph)
        self.state_dict: Dict[str, torch.Tensor] = {}  # TODO(T157676982)

    def serialize_operator(
        self,
        target: Union[
            str,
            EdgeOpOverload,
            BackendOpOverload,
            torch._ops.OpOverload,
            torch._ops.HigherOrderOperator,
        ],
    ) -> str:
        """Return the serialized (string) name for ``target``.

        Strings are returned unchanged. Targets whose ``__module__`` lives
        under the edge or backend dialect packages are named
        ``<module>.<__name__>`` with the private ``_ops`` module segment
        rewritten to ``ops``. Everything else is delegated to the parent.
        """
        if isinstance(target, str):
            return target
        elif target.__module__.startswith("executorch.exir.dialects.edge"):
            # TODO(zhxchen17) Maybe provide a function name helper in FX.
            # From torch.fx.node._get_qualified_name
            module = target.__module__.replace(
                "executorch.exir.dialects.edge._ops",
                "executorch.exir.dialects.edge.ops",
            )
            return f"{module}.{target.__name__}"
        elif target.__module__.startswith("executorch.exir.dialects.backend"):
            module = target.__module__.replace(
                "executorch.exir.dialects.backend._ops",
                "executorch.exir.dialects.backend.ops",
            )
            return f"{module}.{target.__name__}"

        return super().serialize_operator(target)

    def handle_call_function(self, node: torch.fx.Node) -> None:
        """Serialize a ``call_function`` node and append it to the graph state.

        ``memory.alloc``, edge-dialect operators and
        ``executorch_call_delegate`` get dedicated handling here; every other
        target is forwarded to the parent implementation.
        """
        assert node.op == "call_function"

        if node.target is memory.alloc:
            ex_node = schema.Node(
                target="memory.alloc",
                inputs=self.serialize_alloc_inputs(node.args),
                outputs=self.serialize_arbitrary_outputs(node),
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return
        elif isinstance(node.target, EdgeOpOverload):
            assert node.target._op is not None
            ex_node = schema.Node(
                target=self.serialize_operator(node.target),
                # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
                # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
                inputs=self.serialize_inputs(node.target._op, node.args, node.kwargs),
                outputs=self.serialize_outputs(node),
                # TODO: create a new tensor_values here, meta might have faketensor info
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return
        elif node.target is delegate.executorch_call_delegate:
            ex_node = schema.Node(
                target=self.serialize_operator(node.target),
                inputs=self.serialize_call_delegate_inputs(node.args),
                outputs=self.serialize_arbitrary_outputs(node),
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return

        super().handle_call_function(node)

    def serialize_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
        """Serialize the outputs of ``node`` via the parent implementation.

        For edge-dialect nodes, ``node.target`` is temporarily replaced by the
        wrapped op (``edge_op._op``) so the parent implementation can be
        reused as-is.
        """
        if not isinstance(node.target, EdgeOpOverload):
            return super().serialize_outputs(node)

        # Store the original edge op
        edge_op = node.target
        # Replace the edge op with the original ATen op so that we can just call into
        # the serialize_outputs implementation present in the parent class.
        node.target = edge_op._op
        try:
            return super().serialize_outputs(node)
        finally:
            # Replace the edge op back, even if the parent call raised, so the
            # caller's graph is never left with a swapped target.
            node.target = edge_op

    def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
        """Return the parent's metadata plus a stringified ``debug_handle`` if present."""
        meta = super().serialize_metadata(node)

        if "debug_handle" in node.meta:
            debug_handle = node.meta["debug_handle"]
            meta["debug_handle"] = str(debug_handle)

        return meta

    def serialize_alloc_inputs(
        self, inputs  # pyre-ignore
    ) -> List[schema.NamedArgument]:
        """
        Serialize the inputs to the memory.alloc function. Since there's no
        specific spec, we just serialize the inputs with a dummy name.
        We serialize the AllocSpec into a string "size;dtype"
        """
        assert len(inputs) == 1

        def serialize_alloc_spec(alloc_spec: memory.AllocSpec) -> schema.Argument:
            return schema.Argument.create(
                as_string=f"{alloc_spec[0]};{export_serialize._TORCH_TO_SERIALIZE_DTYPE[alloc_spec[1]].value}"
            )

        if isinstance(inputs[0], list):
            # One named argument per spec in the list, all sharing one name.
            return [
                schema.NamedArgument(name="alloc_list", arg=serialize_alloc_spec(arg))
                for arg in inputs[0]
            ]
        else:
            # Single value
            return [
                schema.NamedArgument(
                    name="alloc_arg", arg=serialize_alloc_spec(inputs[0])
                )
            ]

    def serialize_arbitrary_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
        """Serialize outputs of a node using only ``node.meta["val"]``.

        A single tensor value yields one ``as_tensor`` argument. Otherwise the
        value is iterated, and output names are taken from downstream
        ``getitem`` users where they exist.
        """
        meta_val = node.meta["val"]

        # Check single value return
        if isinstance(meta_val, torch.Tensor):
            return [
                schema.Argument.create(
                    as_tensor=self.serialize_tensor_output(node.name, meta_val)
                )
            ]

        # There are two possibilities at this point:
        # - This operator returns a list of Tensors.
        # - This operator returns multiple Tensors.
        #
        # Either way, start by gathering a list of TensorArguments with the correct names.
        # For consistent naming with FX, consult the downstream `getitem` node and
        # make sure our outputs have the same name.
        idx_to_name = {}
        for user in node.users:
            if user.target is not operator.getitem:
                continue
            idx_to_name[user.args[1]] = user.name

        for idx, _ in enumerate(meta_val):
            # FX does not emit a getitem node for any outputs that are unused.
            # However, we need a name for them so that the number of outputs will
            # correctly match the schema. Just assign a dummy name.
            if idx not in idx_to_name:
                idx_to_name[idx] = f"{node.name}_unused_{idx}"

        arg_list = []
        for i, element_meta_val in enumerate(meta_val):
            arg_list.append(
                self.serialize_tensor_output(idx_to_name[i], element_meta_val)
            )

        if len(meta_val) == 1:
            # The operator returns a list of tensors
            return [schema.Argument.create(as_tensors=arg_list)]
        else:
            # The operator returns multiple tensors
            return [schema.Argument.create(as_tensor=arg) for arg in arg_list]

    def serialize_graph(self, graph_module: torch.fx.GraphModule) -> schema.Graph:
        """Remember ``graph_module`` on the instance, then defer to the parent."""
        self.original_graph_module: torch.fx.GraphModule = graph_module  # pyre-ignore
        return super().serialize_graph(graph_module)

    def serialize_call_delegate_inputs(
        self, args  # pyre-ignore
    ) -> List[schema.NamedArgument]:
        """Serialize the arguments of an ``executorch_call_delegate`` node.

        ``args[0]`` is the node referring to the lowered module; it is stored
        as a string argument named after that node's target. The remaining
        arguments are serialized as ``delegate_arg_<i>``.
        """
        lowered_module_arg = args[0]
        delegate_args = args[1:]

        serialized_lowered_module = self.serialize_lowered_module(lowered_module_arg)
        serialized_lowered_module_arg = schema.NamedArgument(
            name=lowered_module_arg.target,
            arg=schema.Argument.create(as_string=serialized_lowered_module),
        )

        serialized_args = [serialized_lowered_module_arg]
        for i, arg in enumerate(delegate_args):
            serialized_args.append(
                schema.NamedArgument(
                    name=f"delegate_arg_{i}", arg=self.serialize_input(arg)
                )
            )
        return serialized_args

    def serialize_lowered_module(self, lowered_module_arg: torch.fx.Node) -> str:
        """Serialize the lowered module referenced by a ``get_attr`` node to JSON.

        The module is looked up on the node's owning graph module. Its
        original program, state dict, constants, processed bytes and compile
        specs are packed into a ``SerdeLoweredBackendModule`` and dumped as a
        JSON string.
        """
        assert lowered_module_arg.op == "get_attr"
        assert isinstance(lowered_module_arg.target, str)

        def serialize_bytes(b: bytes) -> str:
            # We want to serialize the bytes to string because JSON cannot
            # serialize bytes.
            # Since the given bytes may be serialized with any encoding, so we
            # want to first encode with base64, and then decode it with
            # ascii. During deserialization we can just directly decode with b64
            # to get the original encoded bytes.
            return base64.b64encode(b).decode("ascii")

        lowered_module = getattr(
            lowered_module_arg.graph.owning_module, lowered_module_arg.target
        )
        assert isinstance(lowered_module, ExirLoweredBackendModule)

        serialized_compile_spec = [
            CompileSpec(cs.key, serialize_bytes(cs.value))
            for cs in lowered_module.compile_specs
        ]

        serialized_artifact = ExportedProgramSerializer().serialize(
            lowered_module.original_module
        )
        assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)

        serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)

        serialized_lowered_module = SerdeLoweredBackendModule(
            original_module=serialized_artifact.exported_program,
            original_state_dict=serialize_bytes(serialized_artifact.state_dict),
            original_constants=serialize_bytes(serialized_artifact.constants),
            processed_bytes=serialized_processed_bytes,
            compile_specs=serialized_compile_spec,
            backend_id=lowered_module.backend_id,
        )

        json_lowered_module = json.dumps(
            export_serialize._dataclass_to_dict(serialized_lowered_module),
            cls=export_serialize.EnumEncoder,
        )
        return json_lowered_module


class ExportedProgramSerializer(export_serialize.ExportedProgramSerializer):
    """Exported-program serializer that uses this module's ``GraphModuleSerializer``."""

    def serialize(
        self, exported_program: ep.ExportedProgram
    ) -> export_serialize._SerializedProgram:
        """
        Args:
            exported_program: Exported Program to serialize

        Returns:
            A ``_SerializedProgram`` built from the serialized program schema,
            the serialized state dict, the serialized constants and the
            serialized example inputs (``b""`` when there are none).
        """

        assert isinstance(exported_program, ep.ExportedProgram)

        gm_serializer = GraphModuleSerializer(
            exported_program.graph_signature, exported_program.module_call_graph
        )
        serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
        serialized_range_constraints = export_serialize.serialize_range_constraints(
            exported_program.range_constraints
        )

        # TODO: Directly serialize exported_program.constants once
        # CustomClassHolders get stored in the ExportedProgram rather than in
        # the graph
        constants = {}
        for n, c in gm_serializer.custom_objs.items():
            constants[n] = c
        for n, t in exported_program.constants.items():
            assert n not in constants
            constants[n] = t

        # Pass whichever of `verifiers` / `dialect` this ExportedProgram exposes.
        additional_kwargs = {}
        if hasattr(exported_program, "verifiers"):
            additional_kwargs["verifiers"] = [
                v.dialect for v in exported_program.verifiers
            ]
        elif hasattr(exported_program, "dialect"):
            additional_kwargs["dialect"] = exported_program.dialect
        serialized_ep = schema.ExportedProgram(
            graph_module=serialized_graph_module,
            opset_version=self.opset_version,
            range_constraints=serialized_range_constraints,
            schema_version=SchemaVersion(
                major=SCHEMA_VERSION[0],
                minor=SCHEMA_VERSION[1],
            ),
            **additional_kwargs,
        )

        # Test canonical form is well defined.
        # TODO : Doesn't pass currently on executorch graphs with alloc nodes.
        # canonicalize(serialized_ep)

        if exported_program.example_inputs is not None:
            example_inputs = export_serialize.serialize_torch_artifact(
                exported_program.example_inputs
            )
        else:
            example_inputs = b""

        return export_serialize._SerializedProgram(
            serialized_ep,
            export_serialize.serialize_torch_artifact(exported_program.state_dict),
            export_serialize.serialize_torch_artifact(constants),
            example_inputs,
        )


class GraphModuleDeserializer(export_serialize.GraphModuleDeserializer):
    def deserialize_operator(self, serialized_target: str) -> str:
        """Resolve a serialized operator name back to an operator.

        Names under the edge/backend dialect ``ops`` packages are resolved by
        attribute lookup on ``exir_ops.edge`` / ``exir_ops.backend``. If any
        path component is missing, the serialized name itself is returned.
        All other names are delegated to the parent class.
        """

        def find_operator(module: _DialectNamespace, serialized_target: str) -> Any:
            # Skip the five leading dotted components, i.e. the
            # "executorch.exir.dialects.<dialect>.ops" prefix matched below.
            serialized_target_names = serialized_target.split(".")[5:]

            target = module
            for name in serialized_target_names:
                if not hasattr(target, name):
                    # Unknown operator: hand back the string unchanged.
                    return serialized_target
                else:
                    target = getattr(target, name)
            return target

        if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
            return find_operator(exir_ops.edge, serialized_target)
        elif serialized_target.startswith("executorch.exir.dialects.backend.ops"):
            return find_operator(exir_ops.backend, serialized_target)

        return super().deserialize_operator(serialized_target)

    # pyre-ignore
    def deserialize_inputs_no_schema(self, serialized_node) -> Any:
        """Deserialize every input of ``serialized_node`` positionally into a tuple."""
        return tuple(
            self.deserialize_input(named_arg.arg) for named_arg in serialized_node.inputs
        )

    # pyre-ignore
    def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
        """Create the FX node for ``serialized_node`` in ``self.graph``.

        ``memory.alloc``, ``executorch_call_delegate``, edge-dialect operators
        and unresolved (string) targets are handled here; anything else is
        forwarded to the parent implementation.
        """
        if target == "memory.alloc":
            args = self.deserialize_alloc_inputs(serialized_node.inputs)
            fx_node = self.graph.create_node(
                "call_function", memory.alloc, args, {}, "alloc"
            )

            self.deserialize_arbitrary_outputs(serialized_node, fx_node)

            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
            return

        elif target is delegate.executorch_call_delegate:
            if (
                len(serialized_node.outputs) == 1
                and serialized_node.outputs[0].type == "as_tensor"
            ):
                # If it's a single tensor return then we can use the name of the
                # node itself
                name = serialized_node.outputs[0].value.name
            else:
                # Otherwise FX will make a name for us, and we'll have `getitem`
                # nodes pointed to that
                name = None

            args = self.deserialize_call_delegate_inputs(serialized_node.inputs)
            fx_node = self.graph.create_node("call_function", target, args, {}, name)

            self.deserialize_arbitrary_outputs(serialized_node, fx_node)

            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
            return
        elif isinstance(target, EdgeOpOverload):
            # For convenience: if this node returns a single tensor, name the
            # newly-created node after it. This ensures that these tensor values
            # have names that are consistent with serialized.
            name = (
                serialized_node.outputs[0].value.name
                if export_serialize._is_single_tensor_return(target._op)
                else None  # FX will generate a name for us.
            )
            args, kwargs = self.deserialize_inputs(target._op, serialized_node)
            fx_node = self.graph.create_node(
                "call_function", target, args, kwargs, name
            )
            self.deserialize_outputs(serialized_node, fx_node)
            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
            return
        elif isinstance(target, str):
            # Create a dummy fake op if the target does not exist
            # because we cannot create a call_function node w/o a
            # callable target
            log.warning(
                f"Could not find operator {target}. Returning fake operator."
            )  # noqa: G004

            # pyre-ignore
            def fake_op(x):
                raise NotImplementedError("Fake op is not meant to be run.")

            fake_op.__name__ = target
            target = fake_op

            args = self.deserialize_inputs_no_schema(serialized_node)
            fx_node = self.graph.create_node("call_function", target, args, None, None)
            self.deserialize_arbitrary_outputs(serialized_node, fx_node)

            return

        super().deserialize_node(serialized_node, target)

    def deserialize_outputs(
        self, serialized_node: schema.Node, fx_node: torch.fx.Node
    ) -> None:
        if isinstance(fx_node.target, EdgeOpOverload):
            # Store the original edge op
            edge_op = fx_node.target
            # Replace the edge op with the original ATen op so that we can just call into
            # node deserialize_outputs implementation present in the parent class.
460*523fa7a6SAndroid Build Coastguard Worker fx_node.target = edge_op._op 461*523fa7a6SAndroid Build Coastguard Worker super().deserialize_outputs(serialized_node, fx_node) 462*523fa7a6SAndroid Build Coastguard Worker # Replace the edge op back. 463*523fa7a6SAndroid Build Coastguard Worker fx_node.target = edge_op 464*523fa7a6SAndroid Build Coastguard Worker else: 465*523fa7a6SAndroid Build Coastguard Worker super().deserialize_outputs(serialized_node, fx_node) 466*523fa7a6SAndroid Build Coastguard Worker 467*523fa7a6SAndroid Build Coastguard Worker def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: 468*523fa7a6SAndroid Build Coastguard Worker res = super().deserialize_metadata(metadata) 469*523fa7a6SAndroid Build Coastguard Worker 470*523fa7a6SAndroid Build Coastguard Worker if debug_handle := metadata.get("debug_handle"): 471*523fa7a6SAndroid Build Coastguard Worker res["debug_handle"] = int(debug_handle) 472*523fa7a6SAndroid Build Coastguard Worker 473*523fa7a6SAndroid Build Coastguard Worker return res 474*523fa7a6SAndroid Build Coastguard Worker 475*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 476*523fa7a6SAndroid Build Coastguard Worker def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]): 477*523fa7a6SAndroid Build Coastguard Worker def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec: 478*523fa7a6SAndroid Build Coastguard Worker serialized_alloc_spec_elems = serialized_alloc_spec.split(";") 479*523fa7a6SAndroid Build Coastguard Worker assert len(serialized_alloc_spec_elems) == 2 480*523fa7a6SAndroid Build Coastguard Worker serialized_size_elems = ( 481*523fa7a6SAndroid Build Coastguard Worker serialized_alloc_spec_elems[0].strip("()").split(",") 482*523fa7a6SAndroid Build Coastguard Worker ) 483*523fa7a6SAndroid Build Coastguard Worker 484*523fa7a6SAndroid Build Coastguard Worker size = tuple(int(x) for x in serialized_size_elems if x != "") 485*523fa7a6SAndroid Build 
Coastguard Worker dtype = export_serialize._SERIALIZE_TO_TORCH_DTYPE[ 486*523fa7a6SAndroid Build Coastguard Worker int(serialized_alloc_spec_elems[1]) 487*523fa7a6SAndroid Build Coastguard Worker ] 488*523fa7a6SAndroid Build Coastguard Worker return (size, dtype) 489*523fa7a6SAndroid Build Coastguard Worker 490*523fa7a6SAndroid Build Coastguard Worker assert serialized_inputs[0].arg.type == "as_string" 491*523fa7a6SAndroid Build Coastguard Worker 492*523fa7a6SAndroid Build Coastguard Worker # Single value 493*523fa7a6SAndroid Build Coastguard Worker if len(serialized_inputs) == 1 and serialized_inputs[0].name == "alloc_arg": 494*523fa7a6SAndroid Build Coastguard Worker res = (deserialize_alloc_spec(serialized_inputs[0].arg.value),) 495*523fa7a6SAndroid Build Coastguard Worker return res 496*523fa7a6SAndroid Build Coastguard Worker 497*523fa7a6SAndroid Build Coastguard Worker alloc_specs = [ 498*523fa7a6SAndroid Build Coastguard Worker deserialize_alloc_spec(serialized_input.arg.value) 499*523fa7a6SAndroid Build Coastguard Worker for serialized_input in serialized_inputs 500*523fa7a6SAndroid Build Coastguard Worker ] 501*523fa7a6SAndroid Build Coastguard Worker return (alloc_specs,) 502*523fa7a6SAndroid Build Coastguard Worker 503*523fa7a6SAndroid Build Coastguard Worker def deserialize_arbitrary_outputs( 504*523fa7a6SAndroid Build Coastguard Worker self, serialized_node: schema.Node, fx_node: torch.fx.Node 505*523fa7a6SAndroid Build Coastguard Worker ) -> None: 506*523fa7a6SAndroid Build Coastguard Worker if len(serialized_node.outputs) == 0: 507*523fa7a6SAndroid Build Coastguard Worker return 508*523fa7a6SAndroid Build Coastguard Worker # Single tensor return 509*523fa7a6SAndroid Build Coastguard Worker elif ( 510*523fa7a6SAndroid Build Coastguard Worker len(serialized_node.outputs) == 1 511*523fa7a6SAndroid Build Coastguard Worker and serialized_node.outputs[0].type == "as_tensor" 512*523fa7a6SAndroid Build Coastguard Worker ): 513*523fa7a6SAndroid Build 
Coastguard Worker return self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) 514*523fa7a6SAndroid Build Coastguard Worker elif len(serialized_node.outputs) == 1 and isinstance( 515*523fa7a6SAndroid Build Coastguard Worker serialized_node.outputs[0].value, 516*523fa7a6SAndroid Build Coastguard Worker (schema.SymIntArgument, schema.SymBoolArgument), 517*523fa7a6SAndroid Build Coastguard Worker ): 518*523fa7a6SAndroid Build Coastguard Worker self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) 519*523fa7a6SAndroid Build Coastguard Worker return 520*523fa7a6SAndroid Build Coastguard Worker 521*523fa7a6SAndroid Build Coastguard Worker self.deserialize_multiple_outputs(serialized_node, fx_node) 522*523fa7a6SAndroid Build Coastguard Worker 523*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 524*523fa7a6SAndroid Build Coastguard Worker def deserialize_call_delegate_inputs( 525*523fa7a6SAndroid Build Coastguard Worker self, serialized_inputs: List[schema.NamedArgument] 526*523fa7a6SAndroid Build Coastguard Worker ): 527*523fa7a6SAndroid Build Coastguard Worker serialized_lowered_module = serialized_inputs[0] 528*523fa7a6SAndroid Build Coastguard Worker lowered_module_node = self.deserialize_lowered_module(serialized_lowered_module) 529*523fa7a6SAndroid Build Coastguard Worker serialized_delegate_inputs = serialized_inputs[1:] 530*523fa7a6SAndroid Build Coastguard Worker args = tuple( 531*523fa7a6SAndroid Build Coastguard Worker self.deserialize_input(input.arg) for input in serialized_delegate_inputs 532*523fa7a6SAndroid Build Coastguard Worker ) 533*523fa7a6SAndroid Build Coastguard Worker return (lowered_module_node,) + args 534*523fa7a6SAndroid Build Coastguard Worker 535*523fa7a6SAndroid Build Coastguard Worker def deserialize_lowered_module( 536*523fa7a6SAndroid Build Coastguard Worker self, serialized_lowered_module_arg: schema.NamedArgument 537*523fa7a6SAndroid Build Coastguard Worker ) -> torch.fx.Node: 538*523fa7a6SAndroid 
Build Coastguard Worker assert serialized_lowered_module_arg.arg.type == "as_string" 539*523fa7a6SAndroid Build Coastguard Worker lowered_module_str = serialized_lowered_module_arg.arg.value 540*523fa7a6SAndroid Build Coastguard Worker json_lowered_module = json.loads(lowered_module_str) 541*523fa7a6SAndroid Build Coastguard Worker serialized_lowered_module = export_serialize._dict_to_dataclass( 542*523fa7a6SAndroid Build Coastguard Worker SerdeLoweredBackendModule, json_lowered_module 543*523fa7a6SAndroid Build Coastguard Worker ) 544*523fa7a6SAndroid Build Coastguard Worker 545*523fa7a6SAndroid Build Coastguard Worker backend_id = serialized_lowered_module.backend_id 546*523fa7a6SAndroid Build Coastguard Worker processed_bytes = base64.b64decode(serialized_lowered_module.processed_bytes) 547*523fa7a6SAndroid Build Coastguard Worker compile_specs = [ 548*523fa7a6SAndroid Build Coastguard Worker delegate_CompileSpec(key=cs.key, value=base64.b64decode(cs.value)) 549*523fa7a6SAndroid Build Coastguard Worker for cs in serialized_lowered_module.compile_specs 550*523fa7a6SAndroid Build Coastguard Worker ] 551*523fa7a6SAndroid Build Coastguard Worker 552*523fa7a6SAndroid Build Coastguard Worker original_module = ExportedProgramDeserializer().deserialize( 553*523fa7a6SAndroid Build Coastguard Worker serialized_lowered_module.original_module, 554*523fa7a6SAndroid Build Coastguard Worker base64.b64decode(serialized_lowered_module.original_state_dict), 555*523fa7a6SAndroid Build Coastguard Worker base64.b64decode(serialized_lowered_module.original_constants), 556*523fa7a6SAndroid Build Coastguard Worker None, 557*523fa7a6SAndroid Build Coastguard Worker ) 558*523fa7a6SAndroid Build Coastguard Worker 559*523fa7a6SAndroid Build Coastguard Worker lowered_module = ExirLoweredBackendModule( 560*523fa7a6SAndroid Build Coastguard Worker original_module, 561*523fa7a6SAndroid Build Coastguard Worker backend_id, 562*523fa7a6SAndroid Build Coastguard Worker processed_bytes, 
563*523fa7a6SAndroid Build Coastguard Worker compile_specs, 564*523fa7a6SAndroid Build Coastguard Worker ) 565*523fa7a6SAndroid Build Coastguard Worker self.module.register_module(serialized_lowered_module_arg.name, lowered_module) 566*523fa7a6SAndroid Build Coastguard Worker return self.graph.get_attr(serialized_lowered_module_arg.name) 567*523fa7a6SAndroid Build Coastguard Worker 568*523fa7a6SAndroid Build Coastguard Worker 569*523fa7a6SAndroid Build Coastguard Workerclass ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer): 570*523fa7a6SAndroid Build Coastguard Worker def deserialize( 571*523fa7a6SAndroid Build Coastguard Worker self, 572*523fa7a6SAndroid Build Coastguard Worker exported_program: export_serialize.ExportedProgram, 573*523fa7a6SAndroid Build Coastguard Worker state_dict: Union[Dict[str, torch.Tensor], bytes], 574*523fa7a6SAndroid Build Coastguard Worker constants: Union[Dict[str, torch.Tensor], bytes], 575*523fa7a6SAndroid Build Coastguard Worker example_inputs: Optional[ 576*523fa7a6SAndroid Build Coastguard Worker Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes] 577*523fa7a6SAndroid Build Coastguard Worker ] = None, 578*523fa7a6SAndroid Build Coastguard Worker ) -> ep.ExportedProgram: 579*523fa7a6SAndroid Build Coastguard Worker assert isinstance(exported_program, export_serialize.ExportedProgram) 580*523fa7a6SAndroid Build Coastguard Worker version = exported_program.schema_version 581*523fa7a6SAndroid Build Coastguard Worker 582*523fa7a6SAndroid Build Coastguard Worker # TODO(zhxchen17) blocked on thrift schema refactor 583*523fa7a6SAndroid Build Coastguard Worker if version.major != SCHEMA_VERSION[0] and not ( 584*523fa7a6SAndroid Build Coastguard Worker version.major == 0 and version.minor == 0 585*523fa7a6SAndroid Build Coastguard Worker ): 586*523fa7a6SAndroid Build Coastguard Worker raise SerializeError( 587*523fa7a6SAndroid Build Coastguard Worker f"Serialized schema version 
{exported_program.schema_version} " 588*523fa7a6SAndroid Build Coastguard Worker f"does not match our current schema version {SCHEMA_VERSION}." 589*523fa7a6SAndroid Build Coastguard Worker ) 590*523fa7a6SAndroid Build Coastguard Worker 591*523fa7a6SAndroid Build Coastguard Worker symbol_name_to_range = { 592*523fa7a6SAndroid Build Coastguard Worker k: symbolic_shapes.ValueRanges( 593*523fa7a6SAndroid Build Coastguard Worker export_serialize._int_to_sympy_int(v.min_val), 594*523fa7a6SAndroid Build Coastguard Worker export_serialize._int_to_sympy_int(v.max_val), 595*523fa7a6SAndroid Build Coastguard Worker ) 596*523fa7a6SAndroid Build Coastguard Worker for k, v in exported_program.range_constraints.items() 597*523fa7a6SAndroid Build Coastguard Worker } 598*523fa7a6SAndroid Build Coastguard Worker res = GraphModuleDeserializer().deserialize( 599*523fa7a6SAndroid Build Coastguard Worker exported_program.graph_module, 600*523fa7a6SAndroid Build Coastguard Worker state_dict, 601*523fa7a6SAndroid Build Coastguard Worker constants, 602*523fa7a6SAndroid Build Coastguard Worker example_inputs, 603*523fa7a6SAndroid Build Coastguard Worker symbol_name_to_range, 604*523fa7a6SAndroid Build Coastguard Worker ) 605*523fa7a6SAndroid Build Coastguard Worker range_constraints = self.deserialize_range_constraints( 606*523fa7a6SAndroid Build Coastguard Worker symbol_name_to_range, 607*523fa7a6SAndroid Build Coastguard Worker res.names_to_symbols, 608*523fa7a6SAndroid Build Coastguard Worker ) 609*523fa7a6SAndroid Build Coastguard Worker model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version 610*523fa7a6SAndroid Build Coastguard Worker self._validate_model_opset_version(model_opset_version) 611*523fa7a6SAndroid Build Coastguard Worker 612*523fa7a6SAndroid Build Coastguard Worker upgrader = GraphModuleOpUpgrader( 613*523fa7a6SAndroid Build Coastguard Worker self.expected_opset_version, model_opset_version 614*523fa7a6SAndroid Build Coastguard Worker ) 
615*523fa7a6SAndroid Build Coastguard Worker 616*523fa7a6SAndroid Build Coastguard Worker dummy_g = torch.fx.Graph() 617*523fa7a6SAndroid Build Coastguard Worker dummy_g.output(()) 618*523fa7a6SAndroid Build Coastguard Worker additional_kwargs = {} 619*523fa7a6SAndroid Build Coastguard Worker if hasattr(exported_program, "verifiers"): 620*523fa7a6SAndroid Build Coastguard Worker additional_kwargs["verifiers"] = [ 621*523fa7a6SAndroid Build Coastguard Worker load_verifier(v) for v in exported_program.verifiers # pyre-ignore 622*523fa7a6SAndroid Build Coastguard Worker ] 623*523fa7a6SAndroid Build Coastguard Worker elif hasattr(exported_program, "dialect"): 624*523fa7a6SAndroid Build Coastguard Worker additional_kwargs["verifier"] = load_verifier( 625*523fa7a6SAndroid Build Coastguard Worker exported_program.dialect # pyre-ignore 626*523fa7a6SAndroid Build Coastguard Worker ) 627*523fa7a6SAndroid Build Coastguard Worker exported_program = ep.ExportedProgram( 628*523fa7a6SAndroid Build Coastguard Worker root=res.graph_module, 629*523fa7a6SAndroid Build Coastguard Worker graph=dummy_g, 630*523fa7a6SAndroid Build Coastguard Worker graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]), 631*523fa7a6SAndroid Build Coastguard Worker state_dict=res.state_dict, # type: ignore[arg-type] 632*523fa7a6SAndroid Build Coastguard Worker range_constraints=range_constraints, 633*523fa7a6SAndroid Build Coastguard Worker module_call_graph=res.module_call_graph, 634*523fa7a6SAndroid Build Coastguard Worker example_inputs=res.example_inputs, 635*523fa7a6SAndroid Build Coastguard Worker constants=res.constants, 636*523fa7a6SAndroid Build Coastguard Worker **additional_kwargs, 637*523fa7a6SAndroid Build Coastguard Worker ) 638*523fa7a6SAndroid Build Coastguard Worker 639*523fa7a6SAndroid Build Coastguard Worker exported_program.graph_module.graph = res.graph_module.graph 640*523fa7a6SAndroid Build Coastguard Worker exported_program._graph_signature = res.signature 
641*523fa7a6SAndroid Build Coastguard Worker for node in res.graph_module.graph.nodes: 642*523fa7a6SAndroid Build Coastguard Worker if node.op == "get_attr": 643*523fa7a6SAndroid Build Coastguard Worker setattr( 644*523fa7a6SAndroid Build Coastguard Worker exported_program.graph_module, 645*523fa7a6SAndroid Build Coastguard Worker node.target, 646*523fa7a6SAndroid Build Coastguard Worker getattr(res.graph_module, node.target), 647*523fa7a6SAndroid Build Coastguard Worker ) 648*523fa7a6SAndroid Build Coastguard Worker return upgrader.upgrade(exported_program) 649*523fa7a6SAndroid Build Coastguard Worker 650*523fa7a6SAndroid Build Coastguard Worker 651*523fa7a6SAndroid Build Coastguard Workerdef serialize( 652*523fa7a6SAndroid Build Coastguard Worker exported_program: ep.ExportedProgram, 653*523fa7a6SAndroid Build Coastguard Worker opset_version: Optional[Dict[str, int]] = None, 654*523fa7a6SAndroid Build Coastguard Worker) -> export_serialize.SerializedArtifact: 655*523fa7a6SAndroid Build Coastguard Worker serialized_artifact = ExportedProgramSerializer(opset_version).serialize( 656*523fa7a6SAndroid Build Coastguard Worker exported_program 657*523fa7a6SAndroid Build Coastguard Worker ) 658*523fa7a6SAndroid Build Coastguard Worker assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram) 659*523fa7a6SAndroid Build Coastguard Worker json_program = json.dumps( 660*523fa7a6SAndroid Build Coastguard Worker export_serialize._dataclass_to_dict(serialized_artifact.exported_program), 661*523fa7a6SAndroid Build Coastguard Worker cls=export_serialize.EnumEncoder, 662*523fa7a6SAndroid Build Coastguard Worker ) 663*523fa7a6SAndroid Build Coastguard Worker json_bytes = json_program.encode("utf-8") 664*523fa7a6SAndroid Build Coastguard Worker artifact = export_serialize.SerializedArtifact( 665*523fa7a6SAndroid Build Coastguard Worker json_bytes, 666*523fa7a6SAndroid Build Coastguard Worker serialized_artifact.state_dict, 667*523fa7a6SAndroid Build Coastguard 
Worker serialized_artifact.constants, 668*523fa7a6SAndroid Build Coastguard Worker serialized_artifact.example_inputs, 669*523fa7a6SAndroid Build Coastguard Worker ) 670*523fa7a6SAndroid Build Coastguard Worker return artifact 671*523fa7a6SAndroid Build Coastguard Worker 672*523fa7a6SAndroid Build Coastguard Worker 673*523fa7a6SAndroid Build Coastguard Workerdef deserialize( 674*523fa7a6SAndroid Build Coastguard Worker artifact: export_serialize.SerializedArtifact, 675*523fa7a6SAndroid Build Coastguard Worker expected_opset_version: Optional[Dict[str, int]] = None, 676*523fa7a6SAndroid Build Coastguard Worker) -> ep.ExportedProgram: 677*523fa7a6SAndroid Build Coastguard Worker assert isinstance(artifact.exported_program, bytes) 678*523fa7a6SAndroid Build Coastguard Worker exported_program_str = artifact.exported_program.decode("utf-8") 679*523fa7a6SAndroid Build Coastguard Worker exported_program_dict = json.loads(exported_program_str) 680*523fa7a6SAndroid Build Coastguard Worker serialized_exported_program = export_serialize._dict_to_dataclass( 681*523fa7a6SAndroid Build Coastguard Worker schema.ExportedProgram, exported_program_dict 682*523fa7a6SAndroid Build Coastguard Worker ) 683*523fa7a6SAndroid Build Coastguard Worker return ExportedProgramDeserializer(expected_opset_version).deserialize( 684*523fa7a6SAndroid Build Coastguard Worker serialized_exported_program, 685*523fa7a6SAndroid Build Coastguard Worker artifact.state_dict, 686*523fa7a6SAndroid Build Coastguard Worker artifact.constants, 687*523fa7a6SAndroid Build Coastguard Worker artifact.example_inputs, 688*523fa7a6SAndroid Build Coastguard Worker ) 689*523fa7a6SAndroid Build Coastguard Worker 690*523fa7a6SAndroid Build Coastguard Worker 691*523fa7a6SAndroid Build Coastguard Workerdef save( 692*523fa7a6SAndroid Build Coastguard Worker ep_save: ep.ExportedProgram, 693*523fa7a6SAndroid Build Coastguard Worker f: Union[str, os.PathLike[str], io.BytesIO], 694*523fa7a6SAndroid Build Coastguard Worker *, 
695*523fa7a6SAndroid Build Coastguard Worker extra_files: Optional[Dict[str, Any]] = None, 696*523fa7a6SAndroid Build Coastguard Worker opset_version: Optional[Dict[str, int]] = None, 697*523fa7a6SAndroid Build Coastguard Worker) -> None: 698*523fa7a6SAndroid Build Coastguard Worker if not isinstance(ep_save, ep.ExportedProgram): 699*523fa7a6SAndroid Build Coastguard Worker raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}") 700*523fa7a6SAndroid Build Coastguard Worker 701*523fa7a6SAndroid Build Coastguard Worker artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version) 702*523fa7a6SAndroid Build Coastguard Worker 703*523fa7a6SAndroid Build Coastguard Worker if isinstance(f, (str, os.PathLike)): 704*523fa7a6SAndroid Build Coastguard Worker f = os.fspath(str(f)) 705*523fa7a6SAndroid Build Coastguard Worker 706*523fa7a6SAndroid Build Coastguard Worker with zipfile.ZipFile(f, "w") as zipf: 707*523fa7a6SAndroid Build Coastguard Worker # Save every field in the SerializedArtifact to a file. 
708*523fa7a6SAndroid Build Coastguard Worker assert isinstance(artifact.exported_program, bytes) 709*523fa7a6SAndroid Build Coastguard Worker zipf.writestr("serialized_exported_program.json", artifact.exported_program) 710*523fa7a6SAndroid Build Coastguard Worker zipf.writestr("serialized_state_dict.pt", artifact.state_dict) 711*523fa7a6SAndroid Build Coastguard Worker zipf.writestr("serialized_constants.pt", artifact.constants) 712*523fa7a6SAndroid Build Coastguard Worker zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs) 713*523fa7a6SAndroid Build Coastguard Worker 714*523fa7a6SAndroid Build Coastguard Worker zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION))) 715*523fa7a6SAndroid Build Coastguard Worker 716*523fa7a6SAndroid Build Coastguard Worker # Add extra files if provided 717*523fa7a6SAndroid Build Coastguard Worker if extra_files: 718*523fa7a6SAndroid Build Coastguard Worker for extra_file_name, content in extra_files.items(): 719*523fa7a6SAndroid Build Coastguard Worker encoded_content = content.encode("utf-8") 720*523fa7a6SAndroid Build Coastguard Worker zipf.writestr(f"extra_files/{extra_file_name}", encoded_content) 721*523fa7a6SAndroid Build Coastguard Worker 722*523fa7a6SAndroid Build Coastguard Worker 723*523fa7a6SAndroid Build Coastguard Workerdef load( 724*523fa7a6SAndroid Build Coastguard Worker f: Union[str, os.PathLike[str], io.BytesIO], 725*523fa7a6SAndroid Build Coastguard Worker *, 726*523fa7a6SAndroid Build Coastguard Worker extra_files: Optional[Dict[str, Any]] = None, 727*523fa7a6SAndroid Build Coastguard Worker expected_opset_version: Optional[Dict[str, int]] = None, 728*523fa7a6SAndroid Build Coastguard Worker) -> ep.ExportedProgram: 729*523fa7a6SAndroid Build Coastguard Worker if isinstance(f, (str, os.PathLike)): 730*523fa7a6SAndroid Build Coastguard Worker f = os.fspath(str(f)) 731*523fa7a6SAndroid Build Coastguard Worker 732*523fa7a6SAndroid Build Coastguard Worker extra_files = extra_files or {} 
733*523fa7a6SAndroid Build Coastguard Worker 734*523fa7a6SAndroid Build Coastguard Worker with zipfile.ZipFile(f, "r") as zipf: 735*523fa7a6SAndroid Build Coastguard Worker # Check the version 736*523fa7a6SAndroid Build Coastguard Worker version = zipf.read("version").decode().split(".") 737*523fa7a6SAndroid Build Coastguard Worker 738*523fa7a6SAndroid Build Coastguard Worker assert len(version) == len(SCHEMA_VERSION) 739*523fa7a6SAndroid Build Coastguard Worker if version[0] != str(SCHEMA_VERSION[0]): 740*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 741*523fa7a6SAndroid Build Coastguard Worker f"Serialized version {version} does not match our current " 742*523fa7a6SAndroid Build Coastguard Worker f"schema version {SCHEMA_VERSION}." 743*523fa7a6SAndroid Build Coastguard Worker ) 744*523fa7a6SAndroid Build Coastguard Worker 745*523fa7a6SAndroid Build Coastguard Worker # Load serialized_ep and serialized_state_dict from the zip file 746*523fa7a6SAndroid Build Coastguard Worker 747*523fa7a6SAndroid Build Coastguard Worker serialized_exported_program: Optional[bytes] = None 748*523fa7a6SAndroid Build Coastguard Worker serialized_state_dict: Optional[bytes] = None 749*523fa7a6SAndroid Build Coastguard Worker serialized_constants: Optional[bytes] = None 750*523fa7a6SAndroid Build Coastguard Worker serialized_example_inputs: Optional[bytes] = None 751*523fa7a6SAndroid Build Coastguard Worker 752*523fa7a6SAndroid Build Coastguard Worker for file_info in zipf.infolist(): 753*523fa7a6SAndroid Build Coastguard Worker file_content = zipf.read(file_info.filename) 754*523fa7a6SAndroid Build Coastguard Worker 755*523fa7a6SAndroid Build Coastguard Worker if file_info.filename == "serialized_exported_program.json": 756*523fa7a6SAndroid Build Coastguard Worker serialized_exported_program = file_content 757*523fa7a6SAndroid Build Coastguard Worker elif file_info.filename == "serialized_state_dict.json": 758*523fa7a6SAndroid Build Coastguard Worker print("This version 
of file is deprecated") 759*523fa7a6SAndroid Build Coastguard Worker serialized_state_dict = file_content 760*523fa7a6SAndroid Build Coastguard Worker elif file_info.filename == "serialized_constants.json": 761*523fa7a6SAndroid Build Coastguard Worker print("This version of file is deprecated") 762*523fa7a6SAndroid Build Coastguard Worker serialized_constants = file_content 763*523fa7a6SAndroid Build Coastguard Worker elif file_info.filename == "serialized_state_dict.pt": 764*523fa7a6SAndroid Build Coastguard Worker serialized_state_dict = file_content 765*523fa7a6SAndroid Build Coastguard Worker elif file_info.filename == "serialized_constants.pt": 766*523fa7a6SAndroid Build Coastguard Worker serialized_constants = file_content 767*523fa7a6SAndroid Build Coastguard Worker elif file_info.filename.startswith("extra_files"): 768*523fa7a6SAndroid Build Coastguard Worker filename = file_info.filename.split("/", 1)[1] 769*523fa7a6SAndroid Build Coastguard Worker extra_files[filename] = file_content.decode("utf-8") 770*523fa7a6SAndroid Build Coastguard Worker elif file_info.filename == "serialized_example_inputs.pt": 771*523fa7a6SAndroid Build Coastguard Worker serialized_example_inputs = file_content 772*523fa7a6SAndroid Build Coastguard Worker 773*523fa7a6SAndroid Build Coastguard Worker assert serialized_exported_program is not None 774*523fa7a6SAndroid Build Coastguard Worker assert serialized_state_dict is not None 775*523fa7a6SAndroid Build Coastguard Worker assert serialized_constants is not None 776*523fa7a6SAndroid Build Coastguard Worker assert serialized_example_inputs is not None 777*523fa7a6SAndroid Build Coastguard Worker 778*523fa7a6SAndroid Build Coastguard Worker artifact: export_serialize.SerializedArtifact = ( 779*523fa7a6SAndroid Build Coastguard Worker export_serialize.SerializedArtifact( 780*523fa7a6SAndroid Build Coastguard Worker serialized_exported_program, 781*523fa7a6SAndroid Build Coastguard Worker serialized_state_dict, 782*523fa7a6SAndroid 
Build Coastguard Worker serialized_constants, 783*523fa7a6SAndroid Build Coastguard Worker serialized_example_inputs, 784*523fa7a6SAndroid Build Coastguard Worker ) 785*523fa7a6SAndroid Build Coastguard Worker ) 786*523fa7a6SAndroid Build Coastguard Worker 787*523fa7a6SAndroid Build Coastguard Worker # Deserialize ExportedProgram 788*523fa7a6SAndroid Build Coastguard Worker ep = deserialize(artifact, expected_opset_version) 789*523fa7a6SAndroid Build Coastguard Worker 790*523fa7a6SAndroid Build Coastguard Worker return ep 791