xref: /aosp_15_r20/external/executorch/exir/serde/serialize.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-ignore-all-errors
8
9import base64
10import io
11import json
12import logging
13import operator
14import os
15import zipfile
16from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
18import executorch.exir as exir
19import executorch.exir.memory as memory
20import executorch.exir.serde.export_serialize as export_serialize
21import executorch.exir.serde.schema as schema
22import torch
23import torch.export.exported_program as ep
24from executorch.exir import delegate
25from executorch.exir.backend.compile_spec_schema import (
26    CompileSpec as delegate_CompileSpec,
27)
28from executorch.exir.dialects._ops import _DialectNamespace, ops as exir_ops
29from executorch.exir.dialects.backend._ops import BackendOpOverload
30from executorch.exir.dialects.edge._ops import EdgeOpOverload
31from executorch.exir.lowered_backend_module import (
32    LoweredBackendModule as ExirLoweredBackendModule,
33)
34from executorch.exir.serde.export_serialize import GraphModuleOpUpgrader, SerializeError
35from executorch.exir.serde.schema import (
36    CompileSpec,
37    LoweredBackendModule as SerdeLoweredBackendModule,
38    SCHEMA_VERSION,
39    SchemaVersion,
40)
41from torch._export.verifier import load_verifier
42from torch.fx.experimental import symbolic_shapes
43
44log: logging.Logger = logging.getLogger(__name__)
45
46
47class GraphModuleSerializer(export_serialize.GraphModuleSerializer):
48    def __init__(
49        self,
50        graph_signature: ep.ExportGraphSignature,
51        module_call_graph: List[ep.ModuleCallEntry],
52    ) -> None:
53        super().__init__(graph_signature, module_call_graph)
54        self.state_dict: Dict[str, torch.Tensor] = {}  # TODO(T157676982)
55
56    def serialize_operator(
57        self,
58        target: Union[
59            str,
60            EdgeOpOverload,
61            BackendOpOverload,
62            torch._ops.OpOverload,
63            torch._ops.HigherOrderOperator,
64        ],
65    ) -> str:
66        if isinstance(target, str):
67            return target
68        elif target.__module__.startswith("executorch.exir.dialects.edge"):
69            # TODO(zhxchen17) Maybe provide a function name helper in FX.
70            # From torch.fx.node._get_qualified_name
71            module = target.__module__.replace(
72                "executorch.exir.dialects.edge._ops",
73                "executorch.exir.dialects.edge.ops",
74            )
75            return f"{module}.{target.__name__}"
76        elif target.__module__.startswith("executorch.exir.dialects.backend"):
77            module = target.__module__.replace(
78                "executorch.exir.dialects.backend._ops",
79                "executorch.exir.dialects.backend.ops",
80            )
81            return f"{module}.{target.__name__}"
82
83        return super().serialize_operator(target)
84
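    # Illustrative sketch of the mapping above (the operator name is a
    # hypothetical example, not taken from a real graph): an EdgeOpOverload
    # whose __module__ is "executorch.exir.dialects.edge._ops" and whose
    # __name__ is "aten.add.Tensor" serializes to
    #     "executorch.exir.dialects.edge.ops.aten.add.Tensor"
    # i.e. the private "_ops" module path is rewritten to the public "ops"
    # namespace so that GraphModuleDeserializer.deserialize_operator below can
    # resolve it again by attribute lookup on exir_ops.edge.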
85    def handle_call_function(self, node: torch.fx.Node) -> None:
86        assert node.op == "call_function"
87
88        if node.target is memory.alloc:
89            ex_node = schema.Node(
90                target="memory.alloc",
91                inputs=self.serialize_alloc_inputs(node.args),
92                outputs=self.serialize_arbitrary_outputs(node),
93                metadata=self.serialize_metadata(node),
94            )
95            self.graph_state.nodes.append(ex_node)
96            return
97        elif isinstance(node.target, EdgeOpOverload):
98            assert node.target._op is not None
99            ex_node = schema.Node(
100                target=self.serialize_operator(node.target),
101                # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
102                # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
103                inputs=self.serialize_inputs(node.target._op, node.args, node.kwargs),
104                outputs=self.serialize_outputs(node),
105                # TODO: create a new tensor_values here, meta might have faketensor info
106                metadata=self.serialize_metadata(node),
107            )
108            self.graph_state.nodes.append(ex_node)
109            return
110        elif node.target is delegate.executorch_call_delegate:
111            ex_node = schema.Node(
112                target=self.serialize_operator(node.target),
113                inputs=self.serialize_call_delegate_inputs(node.args),
114                outputs=self.serialize_arbitrary_outputs(node),
115                metadata=self.serialize_metadata(node),
116            )
117            self.graph_state.nodes.append(ex_node)
118            return
119
120        super().handle_call_function(node)
121
122    def serialize_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
123        if isinstance(node.target, EdgeOpOverload):
124            # Store the original edge op
125            edge_op = node.target
126            # Replace the edge op with the original ATen op so that we can just call into
127            # the serialize_outputs implementation present in the parent class.
128            node.target = edge_op._op
129            ret = super().serialize_outputs(node)
130            # Replace the edge op back.
131            node.target = edge_op
132        else:
133            ret = super().serialize_outputs(node)
134        return ret
135
136    def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
137        meta = super().serialize_metadata(node)
138
139        if "debug_handle" in node.meta:
140            debug_handle = node.meta["debug_handle"]
141            meta["debug_handle"] = str(debug_handle)
142
143        return meta
144
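    # Note: debug_handle is stored as a string in the serialized metadata;
    # GraphModuleDeserializer.deserialize_metadata below parses it back to an
    # int, so integer handles round-trip losslessly.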
145    def serialize_alloc_inputs(
146        self, inputs  # pyre-ignore
147    ) -> List[schema.NamedArgument]:
148        """
149        Serialize the inputs to the memory.alloc function. Since there's no
150        specific spec, we jut serialize the inputs with a dummy name.
151        We serialize the AllocSpec into a string "size;dtype"
152        """
153        assert len(inputs) == 1
154
155        def serialize_alloc_spec(alloc_spec: memory.AllocSpec) -> schema.Argument:
156            return schema.Argument.create(
157                as_string=f"{alloc_spec[0]};{export_serialize._TORCH_TO_SERIALIZE_DTYPE[alloc_spec[1]].value}"
158            )
159
160        if isinstance(inputs[0], list):
161            return [
162                schema.NamedArgument(name="alloc_list", arg=serialize_alloc_spec(arg))
163                for arg in inputs[0]
164            ]
165        else:
166            # Single value
167            return [
168                schema.NamedArgument(
169                    name="alloc_arg", arg=serialize_alloc_spec(inputs[0])
170                )
171            ]
172
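    # Illustrative example (a sketch, not executed): an AllocSpec of
    # ((3, 4), torch.float32) is serialized as the string
    #     "(3, 4);<dtype>"
    # where <dtype> is whatever integer value
    # export_serialize._TORCH_TO_SERIALIZE_DTYPE assigns to torch.float32.
    # deserialize_alloc_inputs below reverses this by splitting on ";".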
173    def serialize_arbitrary_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
174        meta_val = node.meta["val"]
175
176        # Check single value return
177        if isinstance(meta_val, torch.Tensor):
178            return [
179                schema.Argument.create(
180                    as_tensor=self.serialize_tensor_output(node.name, meta_val)
181                )
182            ]
183
184        # There are two possibilities at this point:
185        # - This operator returns a list of Tensors.
186        # - This operator returns multiple Tensors.
187        #
188        # Either way, start by gathering a list of TensorArguments with the correct names.
189        # For consistent naming with FX, consult the downstream `getitem` node and
190        # make sure our outputs have the same name.
191        idx_to_name = {}
192        for user in node.users:
193            if user.target is not operator.getitem:
194                continue
195            idx_to_name[user.args[1]] = user.name
196
197        for idx, _ in enumerate(meta_val):
198            # FX does not emit a getitem node for any outputs that are unused.
199            # However, we need a name for them so that the number of outputs will
200            # correctly match the schema. Just assign a dummy name.
201            if idx not in idx_to_name:
202                idx_to_name[idx] = f"{node.name}_unused_{idx}"
203
204        arg_list = []
205        for i, element_meta_val in enumerate(meta_val):
206            arg_list.append(
207                self.serialize_tensor_output(idx_to_name[i], element_meta_val)
208            )
209
210        if len(meta_val) == 1:
211            # The operator returns a list of tensors
212            return [schema.Argument.create(as_tensors=arg_list)]
213        else:
214            # The operator returns multiple tensors
215            return [schema.Argument.create(as_tensor=arg) for arg in arg_list]
216
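    # Worked example of the naming logic above (node and user names are
    # hypothetical): if `node` has meta["val"] = [t0, t1, t2] and a single
    # downstream getitem user named "y0" that extracts index 0, idx_to_name
    # becomes {0: "y0", 1: "<node>_unused_1", 2: "<node>_unused_2"} and,
    # since len(meta_val) > 1, three separate as_tensor arguments are emitted
    # rather than a single as_tensors list.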
217    def serialize_graph(self, graph_module: torch.fx.GraphModule) -> schema.Graph:
218        self.original_graph_module: torch.fx.GraphModule = graph_module  # pyre-ignore
219        return super().serialize_graph(graph_module)
220
221    def serialize_call_delegate_inputs(
222        self, args  # pyre-ignore
223    ) -> List[schema.NamedArgument]:
224        lowered_module_arg = args[0]
225        delegate_args = args[1:]
226
227        serialized_lowered_module = self.serialize_lowered_module(lowered_module_arg)
228        serialized_lowered_module_arg = schema.NamedArgument(
229            name=lowered_module_arg.target,
230            arg=schema.Argument.create(as_string=serialized_lowered_module),
231        )
232
233        serialized_args = [serialized_lowered_module_arg]
234        for i, arg in enumerate(delegate_args):
235            serialized_args.append(
236                schema.NamedArgument(
237                    name=f"delegate_arg_{i}", arg=self.serialize_input(arg)
238                )
239            )
240        return serialized_args
241
242    def serialize_lowered_module(self, lowered_module_arg: torch.fx.Node) -> str:
243        assert lowered_module_arg.op == "get_attr"
244        assert isinstance(lowered_module_arg.target, str)
245
246        def serialize_bytes(b: bytes) -> str:
247            # We serialize the bytes to a string because JSON cannot serialize
248            # raw bytes.
249            # The given bytes may use any encoding, so we first encode them with
250            # base64 and then decode the result as ASCII. During
251            # deserialization we can simply base64-decode the string to recover
252            # the original bytes.
253            return base64.b64encode(b).decode("ascii")
254
255        lowered_module = getattr(
256            lowered_module_arg.graph.owning_module, lowered_module_arg.target
257        )
258        assert isinstance(lowered_module, ExirLoweredBackendModule)
259
260        serialized_compile_spec = [
261            CompileSpec(cs.key, serialize_bytes(cs.value))
262            for cs in lowered_module.compile_specs
263        ]
264
265        serialized_artifact = ExportedProgramSerializer().serialize(
266            lowered_module.original_module
267        )
268        assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
269
270        serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)
271
272        serialized_lowered_module = SerdeLoweredBackendModule(
273            original_module=serialized_artifact.exported_program,
274            original_state_dict=serialize_bytes(serialized_artifact.state_dict),
275            original_constants=serialize_bytes(serialized_artifact.constants),
276            processed_bytes=serialized_processed_bytes,
277            compile_specs=serialized_compile_spec,
278            backend_id=lowered_module.backend_id,
279        )
280
281        json_lowered_module = json.dumps(
282            export_serialize._dataclass_to_dict(serialized_lowered_module),
283            cls=export_serialize.EnumEncoder,
284        )
285        return json_lowered_module
286
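    # Sketch of the resulting JSON payload (field names come from
    # SerdeLoweredBackendModule; the values shown are placeholders):
    #     {
    #       "original_module": { ...serialized ExportedProgram... },
    #       "original_state_dict": "<base64>",
    #       "original_constants": "<base64>",
    #       "processed_bytes": "<base64>",
    #       "compile_specs": [{"key": "...", "value": "<base64>"}],
    #       "backend_id": "..."
    #     }
    # For reference, serialize_bytes(b"\x00\xff") yields "AP8=".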
287
288class ExportedProgramSerializer(export_serialize.ExportedProgramSerializer):
289    def serialize(
290        self, exported_program: ep.ExportedProgram
291    ) -> export_serialize._SerializedProgram:
292        """
293        Args:
294            exported_program: Exported Program to serialize
295        """
296
297        assert isinstance(exported_program, ep.ExportedProgram)
298
299        gm_serializer = GraphModuleSerializer(
300            exported_program.graph_signature, exported_program.module_call_graph
301        )
302        serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
303        serialized_range_constraints = export_serialize.serialize_range_constraints(
304            exported_program.range_constraints
305        )
306
307        # TODO: Directly serialize exported_program.constants once
308        # CustomClassHolders get stored in the ExportedProgram rather than in
309        # the graph
310        constants = {}
311        for n, c in gm_serializer.custom_objs.items():
312            constants[n] = c
313        for n, t in exported_program.constants.items():
314            assert n not in constants
315            constants[n] = t
316
317        additional_kwargs = {}
318        if hasattr(exported_program, "verifiers"):
319            additional_kwargs["verifiers"] = [
320                v.dialect for v in exported_program.verifiers
321            ]
322        elif hasattr(exported_program, "dialect"):
323            additional_kwargs["dialect"] = exported_program.dialect
324        serialized_ep = schema.ExportedProgram(
325            graph_module=serialized_graph_module,
326            opset_version=self.opset_version,
327            range_constraints=serialized_range_constraints,
328            schema_version=SchemaVersion(
329                major=SCHEMA_VERSION[0],
330                minor=SCHEMA_VERSION[1],
331            ),
332            **additional_kwargs,
333        )
334
335        # Test that the canonical form is well defined.
336        # TODO: Currently doesn't pass on executorch graphs with alloc nodes.
337        # canonicalize(serialized_ep)
338
339        if exported_program.example_inputs is not None:
340            example_inputs = export_serialize.serialize_torch_artifact(
341                exported_program.example_inputs
342            )
343        else:
344            example_inputs = b""
345
346        return export_serialize._SerializedProgram(
347            serialized_ep,
348            export_serialize.serialize_torch_artifact(exported_program.state_dict),
349            export_serialize.serialize_torch_artifact(constants),
350            example_inputs,
351        )
352
353
354class GraphModuleDeserializer(export_serialize.GraphModuleDeserializer):
355    def deserialize_operator(self, serialized_target: str) -> str:
356        def find_operator(module: _DialectNamespace, serialized_target: str) -> str:
357            serialized_target_names = serialized_target.split(".")[5:]
358
359            target = module
360            for name in serialized_target_names:
361                if not hasattr(target, name):
362                    return serialized_target
363                else:
364                    target = getattr(target, name)
365            return target
366
367        if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
368            return find_operator(exir_ops.edge, serialized_target)
369        elif serialized_target.startswith("executorch.exir.dialects.backend.ops"):
370            return find_operator(exir_ops.backend, serialized_target)
371
372        return super().deserialize_operator(serialized_target)
373
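    # Sketch of the reverse mapping (continuing the hypothetical example in
    # GraphModuleSerializer.serialize_operator): for the serialized target
    # "executorch.exir.dialects.edge.ops.aten.add.Tensor", the [5:] slice drops
    # the first five dotted components ("executorch.exir.dialects.edge.ops"),
    # and the remaining ["aten", "add", "Tensor"] are resolved by attribute
    # lookup on exir_ops.edge. If any attribute is missing, the original string
    # is returned unchanged.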
374    # pyre-ignore
375    def deserialize_inputs_no_schema(self, serialized_node) -> Any:
376        return tuple(
377            self.deserialize_input(input.arg) for input in serialized_node.inputs
378        )
379
380    # pyre-ignore
381    def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
382        if target == "memory.alloc":
383            args = self.deserialize_alloc_inputs(serialized_node.inputs)
384            fx_node = self.graph.create_node(
385                "call_function", memory.alloc, args, {}, "alloc"
386            )
387
388            self.deserialize_arbitrary_outputs(serialized_node, fx_node)
389
390            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
391            return
392
393        elif target is delegate.executorch_call_delegate:
394            if (
395                len(serialized_node.outputs) == 1
396                and serialized_node.outputs[0].type == "as_tensor"
397            ):
398                # If it's a single tensor return then we can use the name of the
399                # node itself
400                name = serialized_node.outputs[0].value.name
401            else:
402                # Otherwise FX will make a name for us, and we'll have `getitem`
403                # nodes pointing to it
404                name = None
405
406            args = self.deserialize_call_delegate_inputs(serialized_node.inputs)
407            fx_node = self.graph.create_node("call_function", target, args, {}, name)
408
409            self.deserialize_arbitrary_outputs(serialized_node, fx_node)
410
411            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
412            return
413        elif isinstance(target, EdgeOpOverload):
414            # For convenience: if this node returns a single tensor, name the
415            # newly-created node after it. This ensures that these tensor values
416            # have names that are consistent with the serialized graph.
417            name = (
418                serialized_node.outputs[0].value.name
419                if export_serialize._is_single_tensor_return(target._op)
420                else None  # FX will generate a name for us.
421            )
422            args, kwargs = self.deserialize_inputs(target._op, serialized_node)
423            fx_node = self.graph.create_node(
424                "call_function", target, args, kwargs, name
425            )
426            self.deserialize_outputs(serialized_node, fx_node)
427            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
428            return
429        elif isinstance(target, str):
430            # Create a dummy fake op if the target does not exist
431            # because we cannot create a call_function node w/o a
432            # callable target
433            log.warning(
434                f"Could not find operator {target}. Returning fake operator."
435            )  # noqa: G004
436
437            # pyre-ignore
438            def fake_op(x):
439                raise NotImplementedError("Fake op is not meant to be run.")
440
441            fake_op.__name__ = target
442            target = fake_op
443
444            args = self.deserialize_inputs_no_schema(serialized_node)
445            fx_node = self.graph.create_node("call_function", target, args, None, None)
446            self.deserialize_arbitrary_outputs(serialized_node, fx_node)
447
448            return
449
450        super().deserialize_node(serialized_node, target)
451
452    def deserialize_outputs(
453        self, serialized_node: schema.Node, fx_node: torch.fx.Node
454    ) -> None:
455        if isinstance(fx_node.target, EdgeOpOverload):
456            # Store the original edge op
457            edge_op = fx_node.target
458            # Replace the edge op with the original ATen op so that we can just call into
459            # the deserialize_outputs implementation present in the parent class.
460            fx_node.target = edge_op._op
461            super().deserialize_outputs(serialized_node, fx_node)
462            # Replace the edge op back.
463            fx_node.target = edge_op
464        else:
465            super().deserialize_outputs(serialized_node, fx_node)
466
467    def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
468        res = super().deserialize_metadata(metadata)
469
470        if debug_handle := metadata.get("debug_handle"):
471            res["debug_handle"] = int(debug_handle)
472
473        return res
474
475    # pyre-ignore
476    def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
477        def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
478            serialized_alloc_spec_elems = serialized_alloc_spec.split(";")
479            assert len(serialized_alloc_spec_elems) == 2
480            serialized_size_elems = (
481                serialized_alloc_spec_elems[0].strip("()").split(",")
482            )
483
484            size = tuple(int(x) for x in serialized_size_elems if x != "")
485            dtype = export_serialize._SERIALIZE_TO_TORCH_DTYPE[
486                int(serialized_alloc_spec_elems[1])
487            ]
488            return (size, dtype)
489
490        assert serialized_inputs[0].arg.type == "as_string"
491
492        # Single value
493        if len(serialized_inputs) == 1 and serialized_inputs[0].name == "alloc_arg":
494            res = (deserialize_alloc_spec(serialized_inputs[0].arg.value),)
495            return res
496
497        alloc_specs = [
498            deserialize_alloc_spec(serialized_input.arg.value)
499            for serialized_input in serialized_inputs
500        ]
501        return (alloc_specs,)
502
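    # Edge-case note for deserialize_alloc_spec above (assuming sizes are
    # serialized from plain tuples, as the strip("()") parsing suggests): a
    # zero-dimensional spec serializes its size as "()", which strips to an
    # empty string and splits into [""]; the `if x != ""` filter then yields
    # the empty tuple (), so scalar allocations round-trip correctly.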
503    def deserialize_arbitrary_outputs(
504        self, serialized_node: schema.Node, fx_node: torch.fx.Node
505    ) -> None:
506        if len(serialized_node.outputs) == 0:
507            return
508        # Single tensor return
509        elif (
510            len(serialized_node.outputs) == 1
511            and serialized_node.outputs[0].type == "as_tensor"
512        ):
513            return self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
514        elif len(serialized_node.outputs) == 1 and isinstance(
515            serialized_node.outputs[0].value,
516            (schema.SymIntArgument, schema.SymBoolArgument),
517        ):
518            self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
519            return
520
521        self.deserialize_multiple_outputs(serialized_node, fx_node)
522
523    # pyre-ignore
524    def deserialize_call_delegate_inputs(
525        self, serialized_inputs: List[schema.NamedArgument]
526    ):
527        serialized_lowered_module = serialized_inputs[0]
528        lowered_module_node = self.deserialize_lowered_module(serialized_lowered_module)
529        serialized_delegate_inputs = serialized_inputs[1:]
530        args = tuple(
531            self.deserialize_input(input.arg) for input in serialized_delegate_inputs
532        )
533        return (lowered_module_node,) + args
534
535    def deserialize_lowered_module(
536        self, serialized_lowered_module_arg: schema.NamedArgument
537    ) -> torch.fx.Node:
538        assert serialized_lowered_module_arg.arg.type == "as_string"
539        lowered_module_str = serialized_lowered_module_arg.arg.value
540        json_lowered_module = json.loads(lowered_module_str)
541        serialized_lowered_module = export_serialize._dict_to_dataclass(
542            SerdeLoweredBackendModule, json_lowered_module
543        )
544
545        backend_id = serialized_lowered_module.backend_id
546        processed_bytes = base64.b64decode(serialized_lowered_module.processed_bytes)
547        compile_specs = [
548            delegate_CompileSpec(key=cs.key, value=base64.b64decode(cs.value))
549            for cs in serialized_lowered_module.compile_specs
550        ]
551
552        original_module = ExportedProgramDeserializer().deserialize(
553            serialized_lowered_module.original_module,
554            base64.b64decode(serialized_lowered_module.original_state_dict),
555            base64.b64decode(serialized_lowered_module.original_constants),
556            None,
557        )
558
559        lowered_module = ExirLoweredBackendModule(
560            original_module,
561            backend_id,
562            processed_bytes,
563            compile_specs,
564        )
565        self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
566        return self.graph.get_attr(serialized_lowered_module_arg.name)
567
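    # Note: the reconstructed LoweredBackendModule is registered as a submodule
    # under its original attribute name, and a matching get_attr node is
    # returned, mirroring how executorch_call_delegate referenced the lowered
    # module in the original graph.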
568
569class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
570    def deserialize(
571        self,
572        exported_program: export_serialize.ExportedProgram,
573        state_dict: Union[Dict[str, torch.Tensor], bytes],
574        constants: Union[Dict[str, torch.Tensor], bytes],
575        example_inputs: Optional[
576            Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]
577        ] = None,
578    ) -> ep.ExportedProgram:
579        assert isinstance(exported_program, export_serialize.ExportedProgram)
580        version = exported_program.schema_version
581
582        # TODO(zhxchen17) blocked on thrift schema refactor
583        if version.major != SCHEMA_VERSION[0] and not (
584            version.major == 0 and version.minor == 0
585        ):
586            raise SerializeError(
587                f"Serialized schema version {exported_program.schema_version} "
588                f"does not match our current schema version {SCHEMA_VERSION}."
589            )
590
591        symbol_name_to_range = {
592            k: symbolic_shapes.ValueRanges(
593                export_serialize._int_to_sympy_int(v.min_val),
594                export_serialize._int_to_sympy_int(v.max_val),
595            )
596            for k, v in exported_program.range_constraints.items()
597        }
598        res = GraphModuleDeserializer().deserialize(
599            exported_program.graph_module,
600            state_dict,
601            constants,
602            example_inputs,
603            symbol_name_to_range,
604        )
605        range_constraints = self.deserialize_range_constraints(
606            symbol_name_to_range,
607            res.names_to_symbols,
608        )
609        model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
610        self._validate_model_opset_version(model_opset_version)
611
612        upgrader = GraphModuleOpUpgrader(
613            self.expected_opset_version, model_opset_version
614        )
615
616        dummy_g = torch.fx.Graph()
617        dummy_g.output(())
618        additional_kwargs = {}
619        if hasattr(exported_program, "verifiers"):
620            additional_kwargs["verifiers"] = [
621                load_verifier(v) for v in exported_program.verifiers  # pyre-ignore
622            ]
623        elif hasattr(exported_program, "dialect"):
624            additional_kwargs["verifier"] = load_verifier(
625                exported_program.dialect  # pyre-ignore
626            )
627        exported_program = ep.ExportedProgram(
628            root=res.graph_module,
629            graph=dummy_g,
630            graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]),
631            state_dict=res.state_dict,  # type: ignore[arg-type]
632            range_constraints=range_constraints,
633            module_call_graph=res.module_call_graph,
634            example_inputs=res.example_inputs,
635            constants=res.constants,
636            **additional_kwargs,
637        )
638
639        exported_program.graph_module.graph = res.graph_module.graph
640        exported_program._graph_signature = res.signature
641        for node in res.graph_module.graph.nodes:
642            if node.op == "get_attr":
643                setattr(
644                    exported_program.graph_module,
645                    node.target,
646                    getattr(res.graph_module, node.target),
647                )
648        return upgrader.upgrade(exported_program)
649
650
651def serialize(
652    exported_program: ep.ExportedProgram,
653    opset_version: Optional[Dict[str, int]] = None,
654) -> export_serialize.SerializedArtifact:
655    serialized_artifact = ExportedProgramSerializer(opset_version).serialize(
656        exported_program
657    )
658    assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
659    json_program = json.dumps(
660        export_serialize._dataclass_to_dict(serialized_artifact.exported_program),
661        cls=export_serialize.EnumEncoder,
662    )
663    json_bytes = json_program.encode("utf-8")
664    artifact = export_serialize.SerializedArtifact(
665        json_bytes,
666        serialized_artifact.state_dict,
667        serialized_artifact.constants,
668        serialized_artifact.example_inputs,
669    )
670    return artifact
671
672
673def deserialize(
674    artifact: export_serialize.SerializedArtifact,
675    expected_opset_version: Optional[Dict[str, int]] = None,
676) -> ep.ExportedProgram:
677    assert isinstance(artifact.exported_program, bytes)
678    exported_program_str = artifact.exported_program.decode("utf-8")
679    exported_program_dict = json.loads(exported_program_str)
680    serialized_exported_program = export_serialize._dict_to_dataclass(
681        schema.ExportedProgram, exported_program_dict
682    )
683    return ExportedProgramDeserializer(expected_opset_version).deserialize(
684        serialized_exported_program,
685        artifact.state_dict,
686        artifact.constants,
687        artifact.example_inputs,
688    )
689
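# Round-trip sketch (illustrative; `exported` stands for an ExportedProgram
# produced by the ExecuTorch export pipeline):
#
#     artifact = serialize(exported)
#     restored = deserialize(artifact)
#
# artifact.exported_program holds the JSON-encoded program as UTF-8 bytes;
# state_dict, constants, and example_inputs are serialized torch artifacts.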
690
691def save(
692    ep_save: ep.ExportedProgram,
693    f: Union[str, os.PathLike[str], io.BytesIO],
694    *,
695    extra_files: Optional[Dict[str, Any]] = None,
696    opset_version: Optional[Dict[str, int]] = None,
697) -> None:
698    if not isinstance(ep_save, ep.ExportedProgram):
699        raise TypeError(f"save() expects an ExportedProgram but got {type(ep_save)}")
700
701    artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version)
702
703    if isinstance(f, (str, os.PathLike)):
704        f = os.fspath(str(f))
705
706    with zipfile.ZipFile(f, "w") as zipf:
707        # Save every field in the SerializedArtifact to a file.
708        assert isinstance(artifact.exported_program, bytes)
709        zipf.writestr("serialized_exported_program.json", artifact.exported_program)
710        zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
711        zipf.writestr("serialized_constants.pt", artifact.constants)
712        zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)
713
714        zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
715
716        # Add extra files if provided
717        if extra_files:
718            for extra_file_name, content in extra_files.items():
719                encoded_content = content.encode("utf-8")
720                zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
721
722
723def load(
724    f: Union[str, os.PathLike[str], io.BytesIO],
725    *,
726    extra_files: Optional[Dict[str, Any]] = None,
727    expected_opset_version: Optional[Dict[str, int]] = None,
728) -> ep.ExportedProgram:
729    if isinstance(f, (str, os.PathLike)):
730        f = os.fspath(str(f))
731
732    extra_files = extra_files or {}
733
734    with zipfile.ZipFile(f, "r") as zipf:
735        # Check the version
736        version = zipf.read("version").decode().split(".")
737
738        assert len(version) == len(SCHEMA_VERSION)
739        if version[0] != str(SCHEMA_VERSION[0]):
740            raise RuntimeError(
741                f"Serialized version {version} does not match our current "
742                f"schema version {SCHEMA_VERSION}."
743            )
744
745        # Load the serialized program and its companion artifacts from the zip file
746
747        serialized_exported_program: Optional[bytes] = None
748        serialized_state_dict: Optional[bytes] = None
749        serialized_constants: Optional[bytes] = None
750        serialized_example_inputs: Optional[bytes] = None
751
752        for file_info in zipf.infolist():
753            file_content = zipf.read(file_info.filename)
754
755            if file_info.filename == "serialized_exported_program.json":
756                serialized_exported_program = file_content
757            elif file_info.filename == "serialized_state_dict.json":
758                log.warning("serialized_state_dict.json is a deprecated format")
759                serialized_state_dict = file_content
760            elif file_info.filename == "serialized_constants.json":
761                log.warning("serialized_constants.json is a deprecated format")
762                serialized_constants = file_content
763            elif file_info.filename == "serialized_state_dict.pt":
764                serialized_state_dict = file_content
765            elif file_info.filename == "serialized_constants.pt":
766                serialized_constants = file_content
767            elif file_info.filename.startswith("extra_files"):
768                filename = file_info.filename.split("/", 1)[1]
769                extra_files[filename] = file_content.decode("utf-8")
770            elif file_info.filename == "serialized_example_inputs.pt":
771                serialized_example_inputs = file_content
772
773        assert serialized_exported_program is not None
774        assert serialized_state_dict is not None
775        assert serialized_constants is not None
776        assert serialized_example_inputs is not None
777
778        artifact: export_serialize.SerializedArtifact = (
779            export_serialize.SerializedArtifact(
780                serialized_exported_program,
781                serialized_state_dict,
782                serialized_constants,
783                serialized_example_inputs,
784            )
785        )
786
787        # Deserialize ExportedProgram
788        ep = deserialize(artifact, expected_opset_version)
789
790        return ep
791
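# Save/load sketch (the file name and extra file are placeholders):
#
#     extra: Dict[str, Any] = {}
#     save(exported, "model.ep2", extra_files={"notes.txt": "hello"})
#     restored = load("model.ep2", extra_files=extra)
#
# The archive written by save() is a zip containing
# serialized_exported_program.json, serialized_state_dict.pt,
# serialized_constants.pt, serialized_example_inputs.pt, and a "version"
# entry, plus any extra_files/<name> members; load() checks the major schema
# version before deserializing.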