# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors

import base64
import io
import json
import logging
import operator
import os
import zipfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import executorch.exir as exir
import executorch.exir.memory as memory
import executorch.exir.serde.export_serialize as export_serialize
import executorch.exir.serde.schema as schema
import torch
import torch.export.exported_program as ep
from executorch.exir import delegate
from executorch.exir.backend.compile_spec_schema import (
    CompileSpec as delegate_CompileSpec,
)
from executorch.exir.dialects._ops import _DialectNamespace, ops as exir_ops
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.lowered_backend_module import (
    LoweredBackendModule as ExirLoweredBackendModule,
)
from executorch.exir.serde.export_serialize import GraphModuleOpUpgrader, SerializeError
from executorch.exir.serde.schema import (
    CompileSpec,
    LoweredBackendModule as SerdeLoweredBackendModule,
    SCHEMA_VERSION,
    SchemaVersion,
)
from torch._export.verifier import load_verifier
from torch.fx.experimental import symbolic_shapes

log: logging.Logger = logging.getLogger(__name__)


class GraphModuleSerializer(export_serialize.GraphModuleSerializer):
    def __init__(
        self,
        graph_signature: ep.ExportGraphSignature,
        module_call_graph: List[ep.ModuleCallEntry],
    ) -> None:
        super().__init__(graph_signature, module_call_graph)
        self.state_dict: Dict[str, torch.Tensor] = {}  # TODO(T157676982)

    def serialize_operator(
        self,
        target: Union[
            str,
            EdgeOpOverload,
            BackendOpOverload,
            torch._ops.OpOverload,
            torch._ops.HigherOrderOperator,
        ],
    ) -> str:
        if isinstance(target, str):
            return target
        elif target.__module__.startswith("executorch.exir.dialects.edge"):
            # TODO(zhxchen17) Maybe provide a function name helper in FX.
            # From torch.fx.node._get_qualified_name
            module = target.__module__.replace(
                "executorch.exir.dialects.edge._ops",
                "executorch.exir.dialects.edge.ops",
            )
            return f"{module}.{target.__name__}"
        elif target.__module__.startswith("executorch.exir.dialects.backend"):
            module = target.__module__.replace(
                "executorch.exir.dialects.backend._ops",
                "executorch.exir.dialects.backend.ops",
            )
            return f"{module}.{target.__name__}"

        return super().serialize_operator(target)
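
    # Illustrative example (sketch only): an edge-dialect overload such as
    # exir_ops.edge.aten.add.Tensor is defined under
    # "executorch.exir.dialects.edge._ops", so serialize_operator above would
    # produce a string along the lines of
    #
    #     "executorch.exir.dialects.edge.ops.aten.add.Tensor"
    #
    # The "_ops" -> "ops" rewrite makes the name resolvable via getattr, which is
    # how GraphModuleDeserializer.deserialize_operator below recovers the
    # EdgeOpOverload by walking exir_ops.edge attribute by attribute.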

    def handle_call_function(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"

        if node.target is memory.alloc:
            ex_node = schema.Node(
                target="memory.alloc",
                inputs=self.serialize_alloc_inputs(node.args),
                outputs=self.serialize_arbitrary_outputs(node),
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return
        elif isinstance(node.target, EdgeOpOverload):
            assert node.target._op is not None
            ex_node = schema.Node(
                target=self.serialize_operator(node.target),
                # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
                # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
                inputs=self.serialize_inputs(node.target._op, node.args, node.kwargs),
                outputs=self.serialize_outputs(node),
                # TODO: create a new tensor_values here, meta might have faketensor info
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return
        elif node.target is delegate.executorch_call_delegate:
            ex_node = schema.Node(
                target=self.serialize_operator(node.target),
                inputs=self.serialize_call_delegate_inputs(node.args),
                outputs=self.serialize_arbitrary_outputs(node),
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return

        super().handle_call_function(node)

    def serialize_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
        if isinstance(node.target, EdgeOpOverload):
            # Store the original edge op
            edge_op = node.target
            # Replace the edge op with the original ATen op so that we can just call into
            # the serialize_outputs implementation present in the parent class.
            node.target = edge_op._op
            ret = super().serialize_outputs(node)
            # Replace the edge op back.
            node.target = edge_op
        else:
            ret = super().serialize_outputs(node)
        return ret

    def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
        meta = super().serialize_metadata(node)

        if "debug_handle" in node.meta:
            debug_handle = node.meta["debug_handle"]
            meta["debug_handle"] = str(debug_handle)

        return meta

    def serialize_alloc_inputs(
        self, inputs  # pyre-ignore
    ) -> List[schema.NamedArgument]:
        """
        Serialize the inputs to the memory.alloc function. Since there's no
        specific spec, we just serialize the inputs with a dummy name.
        We serialize each AllocSpec into a string "size;dtype".
        """
        assert len(inputs) == 1

        def serialize_alloc_spec(alloc_spec: memory.AllocSpec) -> schema.Argument:
            return schema.Argument.create(
                as_string=f"{alloc_spec[0]};{export_serialize._TORCH_TO_SERIALIZE_DTYPE[alloc_spec[1]].value}"
            )

        if isinstance(inputs[0], list):
            return [
                schema.NamedArgument(name="alloc_list", arg=serialize_alloc_spec(arg))
                for arg in inputs[0]
            ]
        else:
            # Single value
            return [
                schema.NamedArgument(
                    name="alloc_arg", arg=serialize_alloc_spec(inputs[0])
                )
            ]
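
    # Illustrative example (sketch only): a single alloc spec ((2, 3), torch.float32)
    # is serialized by serialize_alloc_spec above as
    #
    #     f"(2, 3);{export_serialize._TORCH_TO_SERIALIZE_DTYPE[torch.float32].value}"
    #
    # i.e. the "size;dtype" string from the docstring, where dtype is the integer
    # value of the serialized scalar-type enum. deserialize_alloc_inputs below
    # splits on ";", strips the parentheses, and rebuilds ((2, 3), torch.float32).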

    def serialize_arbitrary_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
        meta_val = node.meta["val"]

        # Check single value return
        if isinstance(meta_val, torch.Tensor):
            return [
                schema.Argument.create(
                    as_tensor=self.serialize_tensor_output(node.name, meta_val)
                )
            ]

        # There are two possibilities at this point:
        # - This operator returns a list of Tensors.
        # - This operator returns multiple Tensors.
        #
        # Either way, start by gathering a list of TensorArguments with the correct names.
        # For consistent naming with FX, consult the downstream `getitem` node and
        # make sure our outputs have the same name.
        idx_to_name = {}
        for user in node.users:
            if user.target is not operator.getitem:
                continue
            idx_to_name[user.args[1]] = user.name

        for idx, _ in enumerate(meta_val):
            # FX does not emit a getitem node for any outputs that are unused.
            # However, we need a name for them so that the number of outputs will
            # correctly match the schema. Just assign a dummy name.
            if idx not in idx_to_name:
                idx_to_name[idx] = f"{node.name}_unused_{idx}"

        arg_list = []
        for i, element_meta_val in enumerate(meta_val):
            arg_list.append(
                self.serialize_tensor_output(idx_to_name[i], element_meta_val)
            )

        if len(meta_val) == 1:
            # The operator returns a list of tensors
            return [schema.Argument.create(as_tensors=arg_list)]
        else:
            # The operator returns multiple tensors
            return [schema.Argument.create(as_tensor=arg) for arg in arg_list]

    def serialize_graph(self, graph_module: torch.fx.GraphModule) -> schema.Graph:
        self.original_graph_module: torch.fx.GraphModule = graph_module  # pyre-ignore
        return super().serialize_graph(graph_module)

    def serialize_call_delegate_inputs(
        self, args  # pyre-ignore
    ) -> List[schema.NamedArgument]:
        lowered_module_arg = args[0]
        delegate_args = args[1:]

        serialized_lowered_module = self.serialize_lowered_module(lowered_module_arg)
        serialized_lowered_module_arg = schema.NamedArgument(
            name=lowered_module_arg.target,
            arg=schema.Argument.create(as_string=serialized_lowered_module),
        )

        serialized_args = [serialized_lowered_module_arg]
        for i, arg in enumerate(delegate_args):
            serialized_args.append(
                schema.NamedArgument(
                    name=f"delegate_arg_{i}", arg=self.serialize_input(arg)
                )
            )
        return serialized_args

    def serialize_lowered_module(self, lowered_module_arg: torch.fx.Node) -> str:
        assert lowered_module_arg.op == "get_attr"
        assert isinstance(lowered_module_arg.target, str)

        def serialize_bytes(b: bytes) -> str:
            # JSON cannot serialize raw bytes, so we encode them with base64 and
            # decode the result to an ASCII string. During deserialization we
            # simply base64-decode the string to recover the original bytes.
            return base64.b64encode(b).decode("ascii")

        lowered_module = getattr(
            lowered_module_arg.graph.owning_module, lowered_module_arg.target
        )
        assert isinstance(lowered_module, ExirLoweredBackendModule)

        serialized_compile_spec = [
            CompileSpec(cs.key, serialize_bytes(cs.value))
            for cs in lowered_module.compile_specs
        ]

        serialized_artifact = ExportedProgramSerializer().serialize(
            lowered_module.original_module
        )
        assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)

        serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)

        serialized_lowered_module = SerdeLoweredBackendModule(
            original_module=serialized_artifact.exported_program,
            original_state_dict=serialize_bytes(serialized_artifact.state_dict),
            original_constants=serialize_bytes(serialized_artifact.constants),
            processed_bytes=serialized_processed_bytes,
            compile_specs=serialized_compile_spec,
            backend_id=lowered_module.backend_id,
        )

        json_lowered_module = json.dumps(
            export_serialize._dataclass_to_dict(serialized_lowered_module),
            cls=export_serialize.EnumEncoder,
        )
        return json_lowered_module
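
    # Illustrative example (sketch only): the JSON string returned by
    # serialize_lowered_module mirrors the SerdeLoweredBackendModule dataclass, e.g.
    #
    #     {
    #         "original_module": {...serialized ExportedProgram...},
    #         "original_state_dict": "<base64>",
    #         "original_constants": "<base64>",
    #         "processed_bytes": "<base64>",
    #         "compile_specs": [{"key": "...", "value": "<base64>"}],
    #         "backend_id": "XnnpackBackend",
    #     }
    #
    # ("XnnpackBackend" is only an example id; any registered backend may appear.)
    # deserialize_lowered_module in GraphModuleDeserializer below undoes the JSON
    # and base64 wrapping.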


class ExportedProgramSerializer(export_serialize.ExportedProgramSerializer):
    def serialize(
        self, exported_program: ep.ExportedProgram
    ) -> export_serialize._SerializedProgram:
        """
        Args:
            exported_program: Exported Program to serialize
        """

        assert isinstance(exported_program, ep.ExportedProgram)

        gm_serializer = GraphModuleSerializer(
            exported_program.graph_signature, exported_program.module_call_graph
        )
        serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
        serialized_range_constraints = export_serialize.serialize_range_constraints(
            exported_program.range_constraints
        )

        # TODO: Directly serialize exported_program.constants once
        # CustomClassHolders get stored in the ExportedProgram rather than in
        # the graph
        constants = {}
        for n, c in gm_serializer.custom_objs.items():
            constants[n] = c
        for n, t in exported_program.constants.items():
            assert n not in constants
            constants[n] = t

        additional_kwargs = {}
        if hasattr(exported_program, "verifiers"):
            additional_kwargs["verifiers"] = [
                v.dialect for v in exported_program.verifiers
            ]
        elif hasattr(exported_program, "dialect"):
            additional_kwargs["dialect"] = exported_program.dialect
        serialized_ep = schema.ExportedProgram(
            graph_module=serialized_graph_module,
            opset_version=self.opset_version,
            range_constraints=serialized_range_constraints,
            schema_version=SchemaVersion(
                major=SCHEMA_VERSION[0],
                minor=SCHEMA_VERSION[1],
            ),
            **additional_kwargs,
        )

        # Test canonical form is well defined.
        # TODO: Doesn't pass currently on executorch graphs with alloc nodes.
        # canonicalize(serialized_ep)

        if exported_program.example_inputs is not None:
            example_inputs = export_serialize.serialize_torch_artifact(
                exported_program.example_inputs
            )
        else:
            example_inputs = b""

        return export_serialize._SerializedProgram(
            serialized_ep,
            export_serialize.serialize_torch_artifact(exported_program.state_dict),
            export_serialize.serialize_torch_artifact(constants),
            example_inputs,
        )


class GraphModuleDeserializer(export_serialize.GraphModuleDeserializer):
    def deserialize_operator(self, serialized_target: str) -> str:
        def find_operator(module: _DialectNamespace, serialized_target: str) -> str:
            serialized_target_names = serialized_target.split(".")[5:]

            target = module
            for name in serialized_target_names:
                if not hasattr(target, name):
                    return serialized_target
                else:
                    target = getattr(target, name)
            return target

        if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
            return find_operator(exir_ops.edge, serialized_target)
        elif serialized_target.startswith("executorch.exir.dialects.backend.ops"):
            return find_operator(exir_ops.backend, serialized_target)

        return super().deserialize_operator(serialized_target)

    # pyre-ignore
    def deserialize_inputs_no_schema(self, serialized_node) -> Any:
        return tuple(
            self.deserialize_input(input.arg) for input in serialized_node.inputs
        )

    # pyre-ignore
    def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
        if target == "memory.alloc":
            args = self.deserialize_alloc_inputs(serialized_node.inputs)
            fx_node = self.graph.create_node(
                "call_function", memory.alloc, args, {}, "alloc"
            )

            self.deserialize_arbitrary_outputs(serialized_node, fx_node)

            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
            return

        elif target is delegate.executorch_call_delegate:
            if (
                len(serialized_node.outputs) == 1
                and serialized_node.outputs[0].type == "as_tensor"
            ):
                # If it's a single tensor return then we can use the name of the
                # node itself
                name = serialized_node.outputs[0].value.name
            else:
                # Otherwise FX will make a name for us, and we'll have `getitem`
                # nodes pointed to that
                name = None

            args = self.deserialize_call_delegate_inputs(serialized_node.inputs)
            fx_node = self.graph.create_node("call_function", target, args, {}, name)

            self.deserialize_arbitrary_outputs(serialized_node, fx_node)

            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
            return
        elif isinstance(target, EdgeOpOverload):
            # For convenience: if this node returns a single tensor, name the
            # newly-created node after it. This ensures that these tensor values
            # have names that are consistent with the serialized graph.
            name = (
                serialized_node.outputs[0].value.name
                if export_serialize._is_single_tensor_return(target._op)
                else None  # FX will generate a name for us.
            )
            args, kwargs = self.deserialize_inputs(target._op, serialized_node)
            fx_node = self.graph.create_node(
                "call_function", target, args, kwargs, name
            )
            self.deserialize_outputs(serialized_node, fx_node)
            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
            return
        elif isinstance(target, str):
            # Create a dummy fake op if the target does not exist
            # because we cannot create a call_function node w/o a
            # callable target
            log.warning(
                f"Could not find operator {target}. Returning fake operator."
            )  # noqa: G004

            # pyre-ignore
            def fake_op(x):
                raise NotImplementedError("Fake op is not meant to be run.")

            fake_op.__name__ = target
            target = fake_op

            args = self.deserialize_inputs_no_schema(serialized_node)
            fx_node = self.graph.create_node("call_function", target, args, None, None)
            self.deserialize_arbitrary_outputs(serialized_node, fx_node)

            return

        super().deserialize_node(serialized_node, target)

    def deserialize_outputs(
        self, serialized_node: schema.Node, fx_node: torch.fx.Node
    ) -> None:
        if isinstance(fx_node.target, EdgeOpOverload):
            # Store the original edge op
            edge_op = fx_node.target
            # Replace the edge op with the original ATen op so that we can just call into
            # the deserialize_outputs implementation present in the parent class.
            fx_node.target = edge_op._op
            super().deserialize_outputs(serialized_node, fx_node)
            # Replace the edge op back.
            fx_node.target = edge_op
        else:
            super().deserialize_outputs(serialized_node, fx_node)

    def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
        res = super().deserialize_metadata(metadata)

        if debug_handle := metadata.get("debug_handle"):
            res["debug_handle"] = int(debug_handle)

        return res

    # pyre-ignore
    def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
        def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
            serialized_alloc_spec_elems = serialized_alloc_spec.split(";")
            assert len(serialized_alloc_spec_elems) == 2
            serialized_size_elems = (
                serialized_alloc_spec_elems[0].strip("()").split(",")
            )

            size = tuple(int(x) for x in serialized_size_elems if x != "")
            dtype = export_serialize._SERIALIZE_TO_TORCH_DTYPE[
                int(serialized_alloc_spec_elems[1])
            ]
            return (size, dtype)

        assert serialized_inputs[0].arg.type == "as_string"

        # Single value
        if len(serialized_inputs) == 1 and serialized_inputs[0].name == "alloc_arg":
            res = (deserialize_alloc_spec(serialized_inputs[0].arg.value),)
            return res

        alloc_specs = [
            deserialize_alloc_spec(serialized_input.arg.value)
            for serialized_input in serialized_inputs
        ]
        return (alloc_specs,)

    def deserialize_arbitrary_outputs(
        self, serialized_node: schema.Node, fx_node: torch.fx.Node
    ) -> None:
        if len(serialized_node.outputs) == 0:
            return
        # Single tensor return
        elif (
            len(serialized_node.outputs) == 1
            and serialized_node.outputs[0].type == "as_tensor"
        ):
            return self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
        elif len(serialized_node.outputs) == 1 and isinstance(
            serialized_node.outputs[0].value,
            (schema.SymIntArgument, schema.SymBoolArgument),
        ):
            self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
            return

        self.deserialize_multiple_outputs(serialized_node, fx_node)

    # pyre-ignore
    def deserialize_call_delegate_inputs(
        self, serialized_inputs: List[schema.NamedArgument]
    ):
        serialized_lowered_module = serialized_inputs[0]
        lowered_module_node = self.deserialize_lowered_module(serialized_lowered_module)
        serialized_delegate_inputs = serialized_inputs[1:]
        args = tuple(
            self.deserialize_input(input.arg) for input in serialized_delegate_inputs
        )
        return (lowered_module_node,) + args

    def deserialize_lowered_module(
        self, serialized_lowered_module_arg: schema.NamedArgument
    ) -> torch.fx.Node:
        assert serialized_lowered_module_arg.arg.type == "as_string"
        lowered_module_str = serialized_lowered_module_arg.arg.value
        json_lowered_module = json.loads(lowered_module_str)
        serialized_lowered_module = export_serialize._dict_to_dataclass(
            SerdeLoweredBackendModule, json_lowered_module
        )

        backend_id = serialized_lowered_module.backend_id
        processed_bytes = base64.b64decode(serialized_lowered_module.processed_bytes)
        compile_specs = [
            delegate_CompileSpec(key=cs.key, value=base64.b64decode(cs.value))
            for cs in serialized_lowered_module.compile_specs
        ]

        original_module = ExportedProgramDeserializer().deserialize(
            serialized_lowered_module.original_module,
            base64.b64decode(serialized_lowered_module.original_state_dict),
            base64.b64decode(serialized_lowered_module.original_constants),
            None,
        )

        lowered_module = ExirLoweredBackendModule(
            original_module,
            backend_id,
            processed_bytes,
            compile_specs,
        )
        self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
        return self.graph.get_attr(serialized_lowered_module_arg.name)


class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
    def deserialize(
        self,
        exported_program: export_serialize.ExportedProgram,
        state_dict: Union[Dict[str, torch.Tensor], bytes],
        constants: Union[Dict[str, torch.Tensor], bytes],
        example_inputs: Optional[
            Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]
        ] = None,
    ) -> ep.ExportedProgram:
        assert isinstance(exported_program, export_serialize.ExportedProgram)
        version = exported_program.schema_version

        # TODO(zhxchen17) blocked on thrift schema refactor
        if version.major != SCHEMA_VERSION[0] and not (
            version.major == 0 and version.minor == 0
        ):
            raise SerializeError(
                f"Serialized schema version {exported_program.schema_version} "
                f"does not match our current schema version {SCHEMA_VERSION}."
            )

        symbol_name_to_range = {
            k: symbolic_shapes.ValueRanges(
                export_serialize._int_to_sympy_int(v.min_val),
                export_serialize._int_to_sympy_int(v.max_val),
            )
            for k, v in exported_program.range_constraints.items()
        }
        res = GraphModuleDeserializer().deserialize(
            exported_program.graph_module,
            state_dict,
            constants,
            example_inputs,
            symbol_name_to_range,
        )
        range_constraints = self.deserialize_range_constraints(
            symbol_name_to_range,
            res.names_to_symbols,
        )
        model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
        self._validate_model_opset_version(model_opset_version)

        upgrader = GraphModuleOpUpgrader(
            self.expected_opset_version, model_opset_version
        )

        dummy_g = torch.fx.Graph()
        dummy_g.output(())
        additional_kwargs = {}
        if hasattr(exported_program, "verifiers"):
            additional_kwargs["verifiers"] = [
                load_verifier(v) for v in exported_program.verifiers  # pyre-ignore
            ]
        elif hasattr(exported_program, "dialect"):
            additional_kwargs["verifier"] = load_verifier(
                exported_program.dialect  # pyre-ignore
            )
        exported_program = ep.ExportedProgram(
            root=res.graph_module,
            graph=dummy_g,
            graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]),
            state_dict=res.state_dict,  # type: ignore[arg-type]
            range_constraints=range_constraints,
            module_call_graph=res.module_call_graph,
            example_inputs=res.example_inputs,
            constants=res.constants,
            **additional_kwargs,
        )

        exported_program.graph_module.graph = res.graph_module.graph
        exported_program._graph_signature = res.signature
        for node in res.graph_module.graph.nodes:
            if node.op == "get_attr":
                setattr(
                    exported_program.graph_module,
                    node.target,
                    getattr(res.graph_module, node.target),
                )
        return upgrader.upgrade(exported_program)


def serialize(
    exported_program: ep.ExportedProgram,
    opset_version: Optional[Dict[str, int]] = None,
) -> export_serialize.SerializedArtifact:
    serialized_artifact = ExportedProgramSerializer(opset_version).serialize(
        exported_program
    )
    assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
    json_program = json.dumps(
        export_serialize._dataclass_to_dict(serialized_artifact.exported_program),
        cls=export_serialize.EnumEncoder,
    )
    json_bytes = json_program.encode("utf-8")
    artifact = export_serialize.SerializedArtifact(
        json_bytes,
        serialized_artifact.state_dict,
        serialized_artifact.constants,
        serialized_artifact.example_inputs,
    )
    return artifact


def deserialize(
    artifact: export_serialize.SerializedArtifact,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ep.ExportedProgram:
    assert isinstance(artifact.exported_program, bytes)
    exported_program_str = artifact.exported_program.decode("utf-8")
    exported_program_dict = json.loads(exported_program_str)
    serialized_exported_program = export_serialize._dict_to_dataclass(
        schema.ExportedProgram, exported_program_dict
    )
    return ExportedProgramDeserializer(expected_opset_version).deserialize(
        serialized_exported_program,
        artifact.state_dict,
        artifact.constants,
        artifact.example_inputs,
    )


def save(
    ep_save: ep.ExportedProgram,
    f: Union[str, os.PathLike[str], io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    opset_version: Optional[Dict[str, int]] = None,
) -> None:
    if not isinstance(ep_save, ep.ExportedProgram):
        raise TypeError(f"save() expects an ExportedProgram but got {type(ep_save)}")

    artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version)

    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(str(f))

    with zipfile.ZipFile(f, "w") as zipf:
        # Save every field in the SerializedArtifact to a file.
        assert isinstance(artifact.exported_program, bytes)
        zipf.writestr("serialized_exported_program.json", artifact.exported_program)
        zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
        zipf.writestr("serialized_constants.pt", artifact.constants)
        zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)

        zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))

        # Add extra files if provided
        if extra_files:
            for extra_file_name, content in extra_files.items():
                encoded_content = content.encode("utf-8")
                zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
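
# The archive written by save() and read back by load() contains the entries
# "serialized_exported_program.json", "serialized_state_dict.pt",
# "serialized_constants.pt", "serialized_example_inputs.pt", a "version" file
# holding the dotted schema version, and optional "extra_files/<name>" entries.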


def load(
    f: Union[str, os.PathLike[str], io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ep.ExportedProgram:
    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(str(f))

    extra_files = extra_files or {}

    with zipfile.ZipFile(f, "r") as zipf:
        # Check the version
        version = zipf.read("version").decode().split(".")

        assert len(version) == len(SCHEMA_VERSION)
        if version[0] != str(SCHEMA_VERSION[0]):
            raise RuntimeError(
                f"Serialized version {version} does not match our current "
                f"schema version {SCHEMA_VERSION}."
            )

        # Load the serialized artifact fields from the zip file

        serialized_exported_program: Optional[bytes] = None
        serialized_state_dict: Optional[bytes] = None
        serialized_constants: Optional[bytes] = None
        serialized_example_inputs: Optional[bytes] = None

        for file_info in zipf.infolist():
            file_content = zipf.read(file_info.filename)

            if file_info.filename == "serialized_exported_program.json":
                serialized_exported_program = file_content
            elif file_info.filename == "serialized_state_dict.json":
                print("This version of the file is deprecated")
                serialized_state_dict = file_content
            elif file_info.filename == "serialized_constants.json":
                print("This version of the file is deprecated")
                serialized_constants = file_content
            elif file_info.filename == "serialized_state_dict.pt":
                serialized_state_dict = file_content
            elif file_info.filename == "serialized_constants.pt":
                serialized_constants = file_content
            elif file_info.filename.startswith("extra_files"):
                filename = file_info.filename.split("/", 1)[1]
                extra_files[filename] = file_content.decode("utf-8")
            elif file_info.filename == "serialized_example_inputs.pt":
                serialized_example_inputs = file_content

        assert serialized_exported_program is not None
        assert serialized_state_dict is not None
        assert serialized_constants is not None
        assert serialized_example_inputs is not None

        artifact: export_serialize.SerializedArtifact = (
            export_serialize.SerializedArtifact(
                serialized_exported_program,
                serialized_state_dict,
                serialized_constants,
                serialized_example_inputs,
            )
        )

        # Deserialize ExportedProgram
        ep = deserialize(artifact, expected_opset_version)

        return ep
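
# Illustrative usage of save()/load() (sketch only; `edge_ep` is assumed to be an
# edge-dialect ExportedProgram as in the serialize()/deserialize() example above):
#
#     buffer = io.BytesIO()
#     save(edge_ep, buffer, extra_files={"note.txt": "exported with serde"})
#     buffer.seek(0)
#
#     extra: Dict[str, Any] = {}
#     restored = load(buffer, extra_files=extra)
#     assert extra["note.txt"] == "exported with serde"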