xref: /aosp_15_r20/external/executorch/exir/serde/serialize.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-ignore-all-errors
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport base64
10*523fa7a6SAndroid Build Coastguard Workerimport io
11*523fa7a6SAndroid Build Coastguard Workerimport json
12*523fa7a6SAndroid Build Coastguard Workerimport logging
13*523fa7a6SAndroid Build Coastguard Workerimport operator
14*523fa7a6SAndroid Build Coastguard Workerimport os
15*523fa7a6SAndroid Build Coastguard Workerimport zipfile
16*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir
19*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.memory as memory
20*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.serde.export_serialize as export_serialize
21*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.serde.schema as schema
22*523fa7a6SAndroid Build Coastguard Workerimport torch
23*523fa7a6SAndroid Build Coastguard Workerimport torch.export.exported_program as ep
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import delegate
25*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import (
26*523fa7a6SAndroid Build Coastguard Worker    CompileSpec as delegate_CompileSpec,
27*523fa7a6SAndroid Build Coastguard Worker)
28*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import _DialectNamespace, ops as exir_ops
29*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.backend._ops import BackendOpOverload
30*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverload
31*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import (
32*523fa7a6SAndroid Build Coastguard Worker    LoweredBackendModule as ExirLoweredBackendModule,
33*523fa7a6SAndroid Build Coastguard Worker)
34*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.serde.export_serialize import GraphModuleOpUpgrader, SerializeError
35*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.serde.schema import (
36*523fa7a6SAndroid Build Coastguard Worker    CompileSpec,
37*523fa7a6SAndroid Build Coastguard Worker    LoweredBackendModule as SerdeLoweredBackendModule,
38*523fa7a6SAndroid Build Coastguard Worker    SCHEMA_VERSION,
39*523fa7a6SAndroid Build Coastguard Worker    SchemaVersion,
40*523fa7a6SAndroid Build Coastguard Worker)
41*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.verifier import load_verifier
42*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.experimental import symbolic_shapes
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Workerlog: logging.Logger = logging.getLogger(__name__)
45*523fa7a6SAndroid Build Coastguard Worker
46*523fa7a6SAndroid Build Coastguard Worker
class GraphModuleSerializer(export_serialize.GraphModuleSerializer):
    """
    Graph module serializer extended with ExecuTorch-specific node handling.

    On top of the upstream export serializer this handles edge/backend dialect
    operators, ``memory.alloc`` nodes, and ``executorch_call_delegate`` nodes
    (whose lowered backend module is serialized inline as a JSON string).
    """

    def __init__(
        self,
        graph_signature: ep.ExportGraphSignature,
        module_call_graph: List[ep.ModuleCallEntry],
    ) -> None:
        super().__init__(graph_signature, module_call_graph)
        self.state_dict: Dict[str, torch.Tensor] = {}  # TODO(T157676982)

    def serialize_operator(
        self,
        target: Union[
            str,
            EdgeOpOverload,
            BackendOpOverload,
            torch._ops.OpOverload,
            torch._ops.HigherOrderOperator,
        ],
    ) -> str:
        """
        Return the name under which ``target`` is recorded in the serialized graph.

        Strings are returned unchanged. Edge and backend dialect operators are
        named ``<module>.<__name__>``, with the private ``._ops`` module path
        rewritten to the public ``.ops`` path. Every other target is delegated
        to the parent serializer.
        """
        if isinstance(target, str):
            return target
        elif target.__module__.startswith("executorch.exir.dialects.edge"):
            # TODO(zhxchen17) Maybe provide a function name helper in FX.
            # From torch.fx.node._get_qualified_name
            module = target.__module__.replace(
                "executorch.exir.dialects.edge._ops",
                "executorch.exir.dialects.edge.ops",
            )
            return f"{module}.{target.__name__}"
        elif target.__module__.startswith("executorch.exir.dialects.backend"):
            module = target.__module__.replace(
                "executorch.exir.dialects.backend._ops",
                "executorch.exir.dialects.backend.ops",
            )
            return f"{module}.{target.__name__}"

        return super().serialize_operator(target)

    def handle_call_function(self, node: torch.fx.Node) -> None:
        """
        Serialize a ``call_function`` node and append it to ``self.graph_state.nodes``.

        Three ExecuTorch-specific targets are handled here; anything else is
        delegated to the parent serializer:

        * ``memory.alloc``: recorded under the literal target ``"memory.alloc"``
          with inputs from :meth:`serialize_alloc_inputs`.
        * ``EdgeOpOverload``: inputs are serialized against the schema of the
          wrapped ``_op``.
        * ``executorch_call_delegate``: the lowered module is serialized inline
          by :meth:`serialize_call_delegate_inputs`.
        """
        assert node.op == "call_function"

        if node.target is memory.alloc:
            ex_node = schema.Node(
                target="memory.alloc",
                inputs=self.serialize_alloc_inputs(node.args),
                outputs=self.serialize_arbitrary_outputs(node),
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return
        elif isinstance(node.target, EdgeOpOverload):
            assert node.target._op is not None
            ex_node = schema.Node(
                target=self.serialize_operator(node.target),
                # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
                # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
                inputs=self.serialize_inputs(node.target._op, node.args, node.kwargs),
                outputs=self.serialize_outputs(node),
                # TODO: create a new tensor_values here, meta might have faketensor info
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return
        elif node.target is delegate.executorch_call_delegate:
            ex_node = schema.Node(
                target=self.serialize_operator(node.target),
                inputs=self.serialize_call_delegate_inputs(node.args),
                outputs=self.serialize_arbitrary_outputs(node),
                metadata=self.serialize_metadata(node),
            )
            self.graph_state.nodes.append(ex_node)
            return

        super().handle_call_function(node)

    def serialize_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
        """
        Serialize the outputs of ``node``.

        For an ``EdgeOpOverload`` target, the node's target is temporarily
        swapped for the wrapped ATen op (``_op``) so the parent implementation
        can be reused. The original target is always restored, even if the
        parent raises, so the caller's FX graph is never left modified.
        """
        if not isinstance(node.target, EdgeOpOverload):
            return super().serialize_outputs(node)

        edge_op = node.target
        # Replace the edge op with the original ATen op so that we can just call
        # into the serialize_outputs implementation present in the parent class.
        node.target = edge_op._op
        try:
            return super().serialize_outputs(node)
        finally:
            # Put the edge op back regardless of success or failure.
            node.target = edge_op

    def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
        """
        Serialize node metadata via the parent class.

        If present, ``node.meta["debug_handle"]`` is added in its ``str()`` form.
        """
        meta = super().serialize_metadata(node)

        if "debug_handle" in node.meta:
            debug_handle = node.meta["debug_handle"]
            meta["debug_handle"] = str(debug_handle)

        return meta

    def serialize_alloc_inputs(
        self, inputs  # pyre-ignore
    ) -> List[schema.NamedArgument]:
        """
        Serialize the inputs to the memory.alloc function.

        Since there is no operator schema for ``memory.alloc``, the inputs are
        serialized under fixed dummy names. Each AllocSpec is encoded as the
        string "size;dtype", where dtype is the serialized dtype enum value.

        ``inputs`` must contain exactly one element: either a single AllocSpec
        (emitted as ``alloc_arg``) or a list of them (each emitted as
        ``alloc_list``).
        """
        assert len(inputs) == 1

        def serialize_alloc_spec(alloc_spec: memory.AllocSpec) -> schema.Argument:
            return schema.Argument.create(
                as_string=f"{alloc_spec[0]};{export_serialize._TORCH_TO_SERIALIZE_DTYPE[alloc_spec[1]].value}"
            )

        if isinstance(inputs[0], list):
            return [
                schema.NamedArgument(name="alloc_list", arg=serialize_alloc_spec(arg))
                for arg in inputs[0]
            ]
        else:
            # Single value
            return [
                schema.NamedArgument(
                    name="alloc_arg", arg=serialize_alloc_spec(inputs[0])
                )
            ]

    def serialize_arbitrary_outputs(self, node: torch.fx.Node) -> List[schema.Argument]:
        """
        Serialize the outputs of ``node`` from ``node.meta["val"]`` alone.

        This is used for ``memory.alloc`` and delegate-call nodes, whose outputs
        are not described by an operator schema.

        * A single ``torch.Tensor`` value becomes one ``as_tensor`` argument.
        * Otherwise the value is iterated as a sequence of tensors:

          * a one-element sequence is emitted as a single ``as_tensors`` list
            argument;
          * any other length is emitted as one ``as_tensor`` argument per
            element.
        """
        meta_val = node.meta["val"]

        # Check single value return
        if isinstance(meta_val, torch.Tensor):
            return [
                schema.Argument.create(
                    as_tensor=self.serialize_tensor_output(node.name, meta_val)
                )
            ]

        # There are two possibilities at this point:
        # - This operator returns a list of Tensors.
        # - This operator returns multiple Tensors.
        #
        # Either way, start by gathering a list of TensorArguments with the
        # correct names. For consistent naming with FX, consult the downstream
        # `getitem` node and make sure our outputs have the same name.
        idx_to_name = {
            user.args[1]: user.name
            for user in node.users
            if user.target is operator.getitem
        }

        arg_list = []
        for idx, element_meta_val in enumerate(meta_val):
            # FX does not emit a getitem node for any outputs that are unused.
            # However, we need a name for them so that the number of outputs
            # will correctly match the schema. Just assign a dummy name.
            name = idx_to_name.get(idx, f"{node.name}_unused_{idx}")
            arg_list.append(self.serialize_tensor_output(name, element_meta_val))

        # NOTE(review): a length-1 sequence is always treated as "returns a
        # list of tensors", so a single-element tuple output cannot be told
        # apart here. Confirm this matches what the deserializer expects.
        if len(meta_val) == 1:
            # The operator returns a list of tensors
            return [schema.Argument.create(as_tensors=arg_list)]
        # The operator returns multiple tensors
        return [schema.Argument.create(as_tensor=arg) for arg in arg_list]

    def serialize_graph(self, graph_module: torch.fx.GraphModule) -> schema.Graph:
        """Remember ``graph_module`` as ``self.original_graph_module``, then serialize it."""
        self.original_graph_module: torch.fx.GraphModule = graph_module  # pyre-ignore
        return super().serialize_graph(graph_module)

    def serialize_call_delegate_inputs(
        self, args  # pyre-ignore
    ) -> List[schema.NamedArgument]:
        """
        Serialize the arguments of an ``executorch_call_delegate`` node.

        ``args[0]`` is the ``get_attr`` node that references the lowered
        backend module. It is serialized to a JSON string and stored under the
        ``get_attr`` node's target name. The remaining arguments are serialized
        positionally as ``delegate_arg_0``, ``delegate_arg_1``, and so on.
        """
        lowered_module_arg = args[0]
        delegate_args = args[1:]

        serialized_lowered_module = self.serialize_lowered_module(lowered_module_arg)
        serialized_lowered_module_arg = schema.NamedArgument(
            name=lowered_module_arg.target,
            arg=schema.Argument.create(as_string=serialized_lowered_module),
        )

        serialized_args = [serialized_lowered_module_arg]
        for i, arg in enumerate(delegate_args):
            serialized_args.append(
                schema.NamedArgument(
                    name=f"delegate_arg_{i}", arg=self.serialize_input(arg)
                )
            )
        return serialized_args

    def serialize_lowered_module(self, lowered_module_arg: torch.fx.Node) -> str:
        """
        Serialize the lowered backend module referenced by a ``get_attr`` node.

        The module is fetched from the node's owning graph module. It is
        converted to a ``SerdeLoweredBackendModule``:

        * Its ``original_module`` is serialized recursively with
          ``ExportedProgramSerializer``.
        * All byte payloads are base64-encoded so they can be stored in JSON:
          compile-spec values, processed bytes, and the original state dict
          and constants.

        Returns:
            The JSON string for the serialized module.
        """
        assert lowered_module_arg.op == "get_attr"
        assert isinstance(lowered_module_arg.target, str)

        def serialize_bytes(b: bytes) -> str:
            # We want to serialize the bytes to string because JSON cannot
            # serialize bytes.
            # Since the given bytes may be serialized with any encoding, so we
            # want to first encode with base64, and then decode it with
            # ascii. During deserialization we can just directly decode with b64
            # to get the original encoded bytes.
            return base64.b64encode(b).decode("ascii")

        lowered_module = getattr(
            lowered_module_arg.graph.owning_module, lowered_module_arg.target
        )
        assert isinstance(lowered_module, ExirLoweredBackendModule)

        serialized_compile_spec = [
            CompileSpec(cs.key, serialize_bytes(cs.value))
            for cs in lowered_module.compile_specs
        ]

        serialized_artifact = ExportedProgramSerializer().serialize(
            lowered_module.original_module
        )
        assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)

        serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)

        serialized_lowered_module = SerdeLoweredBackendModule(
            original_module=serialized_artifact.exported_program,
            original_state_dict=serialize_bytes(serialized_artifact.state_dict),
            original_constants=serialize_bytes(serialized_artifact.constants),
            processed_bytes=serialized_processed_bytes,
            compile_specs=serialized_compile_spec,
            backend_id=lowered_module.backend_id,
        )

        json_lowered_module = json.dumps(
            export_serialize._dataclass_to_dict(serialized_lowered_module),
            cls=export_serialize.EnumEncoder,
        )
        return json_lowered_module
286*523fa7a6SAndroid Build Coastguard Worker
287*523fa7a6SAndroid Build Coastguard Worker
class ExportedProgramSerializer(export_serialize.ExportedProgramSerializer):
    """Serialize an ``ExportedProgram`` using the ExecuTorch-aware graph serializer."""

    def serialize(
        self, exported_program: ep.ExportedProgram
    ) -> export_serialize._SerializedProgram:
        """
        Serialize ``exported_program`` into a ``_SerializedProgram``.

        Args:
            exported_program: Exported Program to serialize

        Returns:
            A ``_SerializedProgram`` holding:

            * the schema-level ``ExportedProgram``;
            * the torch-serialized state dict;
            * the torch-serialized constants (graph custom objects merged with
              ``exported_program.constants``);
            * the torch-serialized example inputs, or ``b""`` when the program
              has none.

        Raises:
            SerializeError: if a name appears both among the custom objects
                collected from the graph and in ``exported_program.constants``.
        """

        assert isinstance(exported_program, ep.ExportedProgram)

        gm_serializer = GraphModuleSerializer(
            exported_program.graph_signature, exported_program.module_call_graph
        )
        serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
        serialized_range_constraints = export_serialize.serialize_range_constraints(
            exported_program.range_constraints
        )

        # TODO: Directly serialize exported_program.constants once
        # CustomClassHolders get stored in the ExportedProgram rather than in
        # the graph
        # custom_objs is read only after gm_serializer.serialize() above has run.
        constants = dict(gm_serializer.custom_objs.items())
        for name, value in exported_program.constants.items():
            # Explicit check instead of `assert`: under `python -O` an assert is
            # stripped and a clash would silently overwrite a custom object.
            if name in constants:
                raise SerializeError(
                    f"Constant {name!r} collides with a custom object of the same name"
                )
            constants[name] = value

        # Record the verifier dialects (newer ExportedProgram) or the single
        # dialect (older ExportedProgram), whichever attribute exists.
        additional_kwargs = {}
        if hasattr(exported_program, "verifiers"):
            additional_kwargs["verifiers"] = [
                v.dialect for v in exported_program.verifiers
            ]
        elif hasattr(exported_program, "dialect"):
            additional_kwargs["dialect"] = exported_program.dialect
        serialized_ep = schema.ExportedProgram(
            graph_module=serialized_graph_module,
            opset_version=self.opset_version,
            range_constraints=serialized_range_constraints,
            schema_version=SchemaVersion(
                major=SCHEMA_VERSION[0],
                minor=SCHEMA_VERSION[1],
            ),
            **additional_kwargs,
        )

        # Test canonical form is well defined.
        # TODO : Doesn't pass currently on executorch graphs with alloc nodes.
        # canonicalize(serialized_ep)

        if exported_program.example_inputs is not None:
            example_inputs = export_serialize.serialize_torch_artifact(
                exported_program.example_inputs
            )
        else:
            example_inputs = b""

        return export_serialize._SerializedProgram(
            serialized_ep,
            export_serialize.serialize_torch_artifact(exported_program.state_dict),
            export_serialize.serialize_torch_artifact(constants),
            example_inputs,
        )
352*523fa7a6SAndroid Build Coastguard Worker
353*523fa7a6SAndroid Build Coastguard Worker
354*523fa7a6SAndroid Build Coastguard Workerclass GraphModuleDeserializer(export_serialize.GraphModuleDeserializer):
355*523fa7a6SAndroid Build Coastguard Worker    def deserialize_operator(self, serialized_target: str) -> str:
356*523fa7a6SAndroid Build Coastguard Worker        def find_operator(module: _DialectNamespace, serialized_target: str) -> str:
357*523fa7a6SAndroid Build Coastguard Worker            serialized_target_names = serialized_target.split(".")[5:]
358*523fa7a6SAndroid Build Coastguard Worker
359*523fa7a6SAndroid Build Coastguard Worker            target = module
360*523fa7a6SAndroid Build Coastguard Worker            for name in serialized_target_names:
361*523fa7a6SAndroid Build Coastguard Worker                if not hasattr(target, name):
362*523fa7a6SAndroid Build Coastguard Worker                    return serialized_target
363*523fa7a6SAndroid Build Coastguard Worker                else:
364*523fa7a6SAndroid Build Coastguard Worker                    target = getattr(target, name)
365*523fa7a6SAndroid Build Coastguard Worker            return target
366*523fa7a6SAndroid Build Coastguard Worker
367*523fa7a6SAndroid Build Coastguard Worker        if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
368*523fa7a6SAndroid Build Coastguard Worker            return find_operator(exir_ops.edge, serialized_target)
369*523fa7a6SAndroid Build Coastguard Worker        elif serialized_target.startswith("executorch.exir.dialects.backend.ops"):
370*523fa7a6SAndroid Build Coastguard Worker            return find_operator(exir_ops.backend, serialized_target)
371*523fa7a6SAndroid Build Coastguard Worker
372*523fa7a6SAndroid Build Coastguard Worker        return super().deserialize_operator(serialized_target)
373*523fa7a6SAndroid Build Coastguard Worker
374*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
375*523fa7a6SAndroid Build Coastguard Worker    def deserialize_inputs_no_schema(self, serialized_node) -> Any:
376*523fa7a6SAndroid Build Coastguard Worker        return tuple(
377*523fa7a6SAndroid Build Coastguard Worker            self.deserialize_input(input.arg) for input in serialized_node.inputs
378*523fa7a6SAndroid Build Coastguard Worker        )
379*523fa7a6SAndroid Build Coastguard Worker
380*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
381*523fa7a6SAndroid Build Coastguard Worker    def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
382*523fa7a6SAndroid Build Coastguard Worker        if target == "memory.alloc":
383*523fa7a6SAndroid Build Coastguard Worker            args = self.deserialize_alloc_inputs(serialized_node.inputs)
384*523fa7a6SAndroid Build Coastguard Worker            fx_node = self.graph.create_node(
385*523fa7a6SAndroid Build Coastguard Worker                "call_function", memory.alloc, args, {}, "alloc"
386*523fa7a6SAndroid Build Coastguard Worker            )
387*523fa7a6SAndroid Build Coastguard Worker
388*523fa7a6SAndroid Build Coastguard Worker            self.deserialize_arbitrary_outputs(serialized_node, fx_node)
389*523fa7a6SAndroid Build Coastguard Worker
390*523fa7a6SAndroid Build Coastguard Worker            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
391*523fa7a6SAndroid Build Coastguard Worker            return
392*523fa7a6SAndroid Build Coastguard Worker
393*523fa7a6SAndroid Build Coastguard Worker        elif target is delegate.executorch_call_delegate:
394*523fa7a6SAndroid Build Coastguard Worker            if (
395*523fa7a6SAndroid Build Coastguard Worker                len(serialized_node.outputs) == 1
396*523fa7a6SAndroid Build Coastguard Worker                and serialized_node.outputs[0].type == "as_tensor"
397*523fa7a6SAndroid Build Coastguard Worker            ):
398*523fa7a6SAndroid Build Coastguard Worker                # If it's a single tensor return then we can use the name of the
399*523fa7a6SAndroid Build Coastguard Worker                # node itself
400*523fa7a6SAndroid Build Coastguard Worker                name = serialized_node.outputs[0].value.name
401*523fa7a6SAndroid Build Coastguard Worker            else:
402*523fa7a6SAndroid Build Coastguard Worker                # Otherwise FX will make a name for us, and we'll have `getitem`
403*523fa7a6SAndroid Build Coastguard Worker                # nodes pointed to that
404*523fa7a6SAndroid Build Coastguard Worker                name = None
405*523fa7a6SAndroid Build Coastguard Worker
406*523fa7a6SAndroid Build Coastguard Worker            args = self.deserialize_call_delegate_inputs(serialized_node.inputs)
407*523fa7a6SAndroid Build Coastguard Worker            fx_node = self.graph.create_node("call_function", target, args, {}, name)
408*523fa7a6SAndroid Build Coastguard Worker
409*523fa7a6SAndroid Build Coastguard Worker            self.deserialize_arbitrary_outputs(serialized_node, fx_node)
410*523fa7a6SAndroid Build Coastguard Worker
411*523fa7a6SAndroid Build Coastguard Worker            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
412*523fa7a6SAndroid Build Coastguard Worker            return
413*523fa7a6SAndroid Build Coastguard Worker        elif isinstance(target, EdgeOpOverload):
414*523fa7a6SAndroid Build Coastguard Worker            # For convenience: if this node returns a single tensor, name the
415*523fa7a6SAndroid Build Coastguard Worker            # newly-created node after it. This ensures that these tensor values
416*523fa7a6SAndroid Build Coastguard Worker            # have names that are consistent with serialized.
417*523fa7a6SAndroid Build Coastguard Worker            name = (
418*523fa7a6SAndroid Build Coastguard Worker                serialized_node.outputs[0].value.name
419*523fa7a6SAndroid Build Coastguard Worker                if export_serialize._is_single_tensor_return(target._op)
420*523fa7a6SAndroid Build Coastguard Worker                else None  # FX will generate a name for us.
421*523fa7a6SAndroid Build Coastguard Worker            )
422*523fa7a6SAndroid Build Coastguard Worker            args, kwargs = self.deserialize_inputs(target._op, serialized_node)
423*523fa7a6SAndroid Build Coastguard Worker            fx_node = self.graph.create_node(
424*523fa7a6SAndroid Build Coastguard Worker                "call_function", target, args, kwargs, name
425*523fa7a6SAndroid Build Coastguard Worker            )
426*523fa7a6SAndroid Build Coastguard Worker            self.deserialize_outputs(serialized_node, fx_node)
427*523fa7a6SAndroid Build Coastguard Worker            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
428*523fa7a6SAndroid Build Coastguard Worker            return
429*523fa7a6SAndroid Build Coastguard Worker        elif isinstance(target, str):
430*523fa7a6SAndroid Build Coastguard Worker            # Create a dummy fake op if the target does not exist
431*523fa7a6SAndroid Build Coastguard Worker            # because we cannot create a call_function node w/o a
432*523fa7a6SAndroid Build Coastguard Worker            # callable target
433*523fa7a6SAndroid Build Coastguard Worker            log.warning(
434*523fa7a6SAndroid Build Coastguard Worker                f"Could not find operator {target}. Returning fake operator."
435*523fa7a6SAndroid Build Coastguard Worker            )  # noqa: G004
436*523fa7a6SAndroid Build Coastguard Worker
437*523fa7a6SAndroid Build Coastguard Worker            # pyre-ignore
438*523fa7a6SAndroid Build Coastguard Worker            def fake_op(x):
439*523fa7a6SAndroid Build Coastguard Worker                raise NotImplementedError("Fake op is not meant to be run.")
440*523fa7a6SAndroid Build Coastguard Worker
441*523fa7a6SAndroid Build Coastguard Worker            fake_op.__name__ = target
442*523fa7a6SAndroid Build Coastguard Worker            target = fake_op
443*523fa7a6SAndroid Build Coastguard Worker
444*523fa7a6SAndroid Build Coastguard Worker            args = self.deserialize_inputs_no_schema(serialized_node)
445*523fa7a6SAndroid Build Coastguard Worker            fx_node = self.graph.create_node("call_function", target, args, None, None)
446*523fa7a6SAndroid Build Coastguard Worker            self.deserialize_arbitrary_outputs(serialized_node, fx_node)
447*523fa7a6SAndroid Build Coastguard Worker
448*523fa7a6SAndroid Build Coastguard Worker            return
449*523fa7a6SAndroid Build Coastguard Worker
450*523fa7a6SAndroid Build Coastguard Worker        super().deserialize_node(serialized_node, target)
451*523fa7a6SAndroid Build Coastguard Worker
452*523fa7a6SAndroid Build Coastguard Worker    def deserialize_outputs(
453*523fa7a6SAndroid Build Coastguard Worker        self, serialized_node: schema.Node, fx_node: torch.fx.Node
454*523fa7a6SAndroid Build Coastguard Worker    ) -> None:
455*523fa7a6SAndroid Build Coastguard Worker        if isinstance(fx_node.target, EdgeOpOverload):
456*523fa7a6SAndroid Build Coastguard Worker            # Store the original edge op
457*523fa7a6SAndroid Build Coastguard Worker            edge_op = fx_node.target
458*523fa7a6SAndroid Build Coastguard Worker            # Replace the edge op with the original ATen op so that we can just call into
459*523fa7a6SAndroid Build Coastguard Worker            # node deserialize_outputs implementation present in the parent class.
460*523fa7a6SAndroid Build Coastguard Worker            fx_node.target = edge_op._op
461*523fa7a6SAndroid Build Coastguard Worker            super().deserialize_outputs(serialized_node, fx_node)
462*523fa7a6SAndroid Build Coastguard Worker            # Replace the edge op back.
463*523fa7a6SAndroid Build Coastguard Worker            fx_node.target = edge_op
464*523fa7a6SAndroid Build Coastguard Worker        else:
465*523fa7a6SAndroid Build Coastguard Worker            super().deserialize_outputs(serialized_node, fx_node)
466*523fa7a6SAndroid Build Coastguard Worker
467*523fa7a6SAndroid Build Coastguard Worker    def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
468*523fa7a6SAndroid Build Coastguard Worker        res = super().deserialize_metadata(metadata)
469*523fa7a6SAndroid Build Coastguard Worker
470*523fa7a6SAndroid Build Coastguard Worker        if debug_handle := metadata.get("debug_handle"):
471*523fa7a6SAndroid Build Coastguard Worker            res["debug_handle"] = int(debug_handle)
472*523fa7a6SAndroid Build Coastguard Worker
473*523fa7a6SAndroid Build Coastguard Worker        return res
474*523fa7a6SAndroid Build Coastguard Worker
475*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
476*523fa7a6SAndroid Build Coastguard Worker    def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
477*523fa7a6SAndroid Build Coastguard Worker        def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
478*523fa7a6SAndroid Build Coastguard Worker            serialized_alloc_spec_elems = serialized_alloc_spec.split(";")
479*523fa7a6SAndroid Build Coastguard Worker            assert len(serialized_alloc_spec_elems) == 2
480*523fa7a6SAndroid Build Coastguard Worker            serialized_size_elems = (
481*523fa7a6SAndroid Build Coastguard Worker                serialized_alloc_spec_elems[0].strip("()").split(",")
482*523fa7a6SAndroid Build Coastguard Worker            )
483*523fa7a6SAndroid Build Coastguard Worker
484*523fa7a6SAndroid Build Coastguard Worker            size = tuple(int(x) for x in serialized_size_elems if x != "")
485*523fa7a6SAndroid Build Coastguard Worker            dtype = export_serialize._SERIALIZE_TO_TORCH_DTYPE[
486*523fa7a6SAndroid Build Coastguard Worker                int(serialized_alloc_spec_elems[1])
487*523fa7a6SAndroid Build Coastguard Worker            ]
488*523fa7a6SAndroid Build Coastguard Worker            return (size, dtype)
489*523fa7a6SAndroid Build Coastguard Worker
490*523fa7a6SAndroid Build Coastguard Worker        assert serialized_inputs[0].arg.type == "as_string"
491*523fa7a6SAndroid Build Coastguard Worker
492*523fa7a6SAndroid Build Coastguard Worker        # Single value
493*523fa7a6SAndroid Build Coastguard Worker        if len(serialized_inputs) == 1 and serialized_inputs[0].name == "alloc_arg":
494*523fa7a6SAndroid Build Coastguard Worker            res = (deserialize_alloc_spec(serialized_inputs[0].arg.value),)
495*523fa7a6SAndroid Build Coastguard Worker            return res
496*523fa7a6SAndroid Build Coastguard Worker
497*523fa7a6SAndroid Build Coastguard Worker        alloc_specs = [
498*523fa7a6SAndroid Build Coastguard Worker            deserialize_alloc_spec(serialized_input.arg.value)
499*523fa7a6SAndroid Build Coastguard Worker            for serialized_input in serialized_inputs
500*523fa7a6SAndroid Build Coastguard Worker        ]
501*523fa7a6SAndroid Build Coastguard Worker        return (alloc_specs,)
502*523fa7a6SAndroid Build Coastguard Worker
503*523fa7a6SAndroid Build Coastguard Worker    def deserialize_arbitrary_outputs(
504*523fa7a6SAndroid Build Coastguard Worker        self, serialized_node: schema.Node, fx_node: torch.fx.Node
505*523fa7a6SAndroid Build Coastguard Worker    ) -> None:
506*523fa7a6SAndroid Build Coastguard Worker        if len(serialized_node.outputs) == 0:
507*523fa7a6SAndroid Build Coastguard Worker            return
508*523fa7a6SAndroid Build Coastguard Worker        # Single tensor return
509*523fa7a6SAndroid Build Coastguard Worker        elif (
510*523fa7a6SAndroid Build Coastguard Worker            len(serialized_node.outputs) == 1
511*523fa7a6SAndroid Build Coastguard Worker            and serialized_node.outputs[0].type == "as_tensor"
512*523fa7a6SAndroid Build Coastguard Worker        ):
513*523fa7a6SAndroid Build Coastguard Worker            return self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
514*523fa7a6SAndroid Build Coastguard Worker        elif len(serialized_node.outputs) == 1 and isinstance(
515*523fa7a6SAndroid Build Coastguard Worker            serialized_node.outputs[0].value,
516*523fa7a6SAndroid Build Coastguard Worker            (schema.SymIntArgument, schema.SymBoolArgument),
517*523fa7a6SAndroid Build Coastguard Worker        ):
518*523fa7a6SAndroid Build Coastguard Worker            self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
519*523fa7a6SAndroid Build Coastguard Worker            return
520*523fa7a6SAndroid Build Coastguard Worker
521*523fa7a6SAndroid Build Coastguard Worker        self.deserialize_multiple_outputs(serialized_node, fx_node)
522*523fa7a6SAndroid Build Coastguard Worker
523*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
524*523fa7a6SAndroid Build Coastguard Worker    def deserialize_call_delegate_inputs(
525*523fa7a6SAndroid Build Coastguard Worker        self, serialized_inputs: List[schema.NamedArgument]
526*523fa7a6SAndroid Build Coastguard Worker    ):
527*523fa7a6SAndroid Build Coastguard Worker        serialized_lowered_module = serialized_inputs[0]
528*523fa7a6SAndroid Build Coastguard Worker        lowered_module_node = self.deserialize_lowered_module(serialized_lowered_module)
529*523fa7a6SAndroid Build Coastguard Worker        serialized_delegate_inputs = serialized_inputs[1:]
530*523fa7a6SAndroid Build Coastguard Worker        args = tuple(
531*523fa7a6SAndroid Build Coastguard Worker            self.deserialize_input(input.arg) for input in serialized_delegate_inputs
532*523fa7a6SAndroid Build Coastguard Worker        )
533*523fa7a6SAndroid Build Coastguard Worker        return (lowered_module_node,) + args
534*523fa7a6SAndroid Build Coastguard Worker
535*523fa7a6SAndroid Build Coastguard Worker    def deserialize_lowered_module(
536*523fa7a6SAndroid Build Coastguard Worker        self, serialized_lowered_module_arg: schema.NamedArgument
537*523fa7a6SAndroid Build Coastguard Worker    ) -> torch.fx.Node:
538*523fa7a6SAndroid Build Coastguard Worker        assert serialized_lowered_module_arg.arg.type == "as_string"
539*523fa7a6SAndroid Build Coastguard Worker        lowered_module_str = serialized_lowered_module_arg.arg.value
540*523fa7a6SAndroid Build Coastguard Worker        json_lowered_module = json.loads(lowered_module_str)
541*523fa7a6SAndroid Build Coastguard Worker        serialized_lowered_module = export_serialize._dict_to_dataclass(
542*523fa7a6SAndroid Build Coastguard Worker            SerdeLoweredBackendModule, json_lowered_module
543*523fa7a6SAndroid Build Coastguard Worker        )
544*523fa7a6SAndroid Build Coastguard Worker
545*523fa7a6SAndroid Build Coastguard Worker        backend_id = serialized_lowered_module.backend_id
546*523fa7a6SAndroid Build Coastguard Worker        processed_bytes = base64.b64decode(serialized_lowered_module.processed_bytes)
547*523fa7a6SAndroid Build Coastguard Worker        compile_specs = [
548*523fa7a6SAndroid Build Coastguard Worker            delegate_CompileSpec(key=cs.key, value=base64.b64decode(cs.value))
549*523fa7a6SAndroid Build Coastguard Worker            for cs in serialized_lowered_module.compile_specs
550*523fa7a6SAndroid Build Coastguard Worker        ]
551*523fa7a6SAndroid Build Coastguard Worker
552*523fa7a6SAndroid Build Coastguard Worker        original_module = ExportedProgramDeserializer().deserialize(
553*523fa7a6SAndroid Build Coastguard Worker            serialized_lowered_module.original_module,
554*523fa7a6SAndroid Build Coastguard Worker            base64.b64decode(serialized_lowered_module.original_state_dict),
555*523fa7a6SAndroid Build Coastguard Worker            base64.b64decode(serialized_lowered_module.original_constants),
556*523fa7a6SAndroid Build Coastguard Worker            None,
557*523fa7a6SAndroid Build Coastguard Worker        )
558*523fa7a6SAndroid Build Coastguard Worker
559*523fa7a6SAndroid Build Coastguard Worker        lowered_module = ExirLoweredBackendModule(
560*523fa7a6SAndroid Build Coastguard Worker            original_module,
561*523fa7a6SAndroid Build Coastguard Worker            backend_id,
562*523fa7a6SAndroid Build Coastguard Worker            processed_bytes,
563*523fa7a6SAndroid Build Coastguard Worker            compile_specs,
564*523fa7a6SAndroid Build Coastguard Worker        )
565*523fa7a6SAndroid Build Coastguard Worker        self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
566*523fa7a6SAndroid Build Coastguard Worker        return self.graph.get_attr(serialized_lowered_module_arg.name)
567*523fa7a6SAndroid Build Coastguard Worker
568*523fa7a6SAndroid Build Coastguard Worker
569*523fa7a6SAndroid Build Coastguard Workerclass ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
570*523fa7a6SAndroid Build Coastguard Worker    def deserialize(
571*523fa7a6SAndroid Build Coastguard Worker        self,
572*523fa7a6SAndroid Build Coastguard Worker        exported_program: export_serialize.ExportedProgram,
573*523fa7a6SAndroid Build Coastguard Worker        state_dict: Union[Dict[str, torch.Tensor], bytes],
574*523fa7a6SAndroid Build Coastguard Worker        constants: Union[Dict[str, torch.Tensor], bytes],
575*523fa7a6SAndroid Build Coastguard Worker        example_inputs: Optional[
576*523fa7a6SAndroid Build Coastguard Worker            Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]
577*523fa7a6SAndroid Build Coastguard Worker        ] = None,
578*523fa7a6SAndroid Build Coastguard Worker    ) -> ep.ExportedProgram:
579*523fa7a6SAndroid Build Coastguard Worker        assert isinstance(exported_program, export_serialize.ExportedProgram)
580*523fa7a6SAndroid Build Coastguard Worker        version = exported_program.schema_version
581*523fa7a6SAndroid Build Coastguard Worker
582*523fa7a6SAndroid Build Coastguard Worker        # TODO(zhxchen17) blocked on thrift schema refactor
583*523fa7a6SAndroid Build Coastguard Worker        if version.major != SCHEMA_VERSION[0] and not (
584*523fa7a6SAndroid Build Coastguard Worker            version.major == 0 and version.minor == 0
585*523fa7a6SAndroid Build Coastguard Worker        ):
586*523fa7a6SAndroid Build Coastguard Worker            raise SerializeError(
587*523fa7a6SAndroid Build Coastguard Worker                f"Serialized schema version {exported_program.schema_version} "
588*523fa7a6SAndroid Build Coastguard Worker                f"does not match our current schema version {SCHEMA_VERSION}."
589*523fa7a6SAndroid Build Coastguard Worker            )
590*523fa7a6SAndroid Build Coastguard Worker
591*523fa7a6SAndroid Build Coastguard Worker        symbol_name_to_range = {
592*523fa7a6SAndroid Build Coastguard Worker            k: symbolic_shapes.ValueRanges(
593*523fa7a6SAndroid Build Coastguard Worker                export_serialize._int_to_sympy_int(v.min_val),
594*523fa7a6SAndroid Build Coastguard Worker                export_serialize._int_to_sympy_int(v.max_val),
595*523fa7a6SAndroid Build Coastguard Worker            )
596*523fa7a6SAndroid Build Coastguard Worker            for k, v in exported_program.range_constraints.items()
597*523fa7a6SAndroid Build Coastguard Worker        }
598*523fa7a6SAndroid Build Coastguard Worker        res = GraphModuleDeserializer().deserialize(
599*523fa7a6SAndroid Build Coastguard Worker            exported_program.graph_module,
600*523fa7a6SAndroid Build Coastguard Worker            state_dict,
601*523fa7a6SAndroid Build Coastguard Worker            constants,
602*523fa7a6SAndroid Build Coastguard Worker            example_inputs,
603*523fa7a6SAndroid Build Coastguard Worker            symbol_name_to_range,
604*523fa7a6SAndroid Build Coastguard Worker        )
605*523fa7a6SAndroid Build Coastguard Worker        range_constraints = self.deserialize_range_constraints(
606*523fa7a6SAndroid Build Coastguard Worker            symbol_name_to_range,
607*523fa7a6SAndroid Build Coastguard Worker            res.names_to_symbols,
608*523fa7a6SAndroid Build Coastguard Worker        )
609*523fa7a6SAndroid Build Coastguard Worker        model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
610*523fa7a6SAndroid Build Coastguard Worker        self._validate_model_opset_version(model_opset_version)
611*523fa7a6SAndroid Build Coastguard Worker
612*523fa7a6SAndroid Build Coastguard Worker        upgrader = GraphModuleOpUpgrader(
613*523fa7a6SAndroid Build Coastguard Worker            self.expected_opset_version, model_opset_version
614*523fa7a6SAndroid Build Coastguard Worker        )
615*523fa7a6SAndroid Build Coastguard Worker
616*523fa7a6SAndroid Build Coastguard Worker        dummy_g = torch.fx.Graph()
617*523fa7a6SAndroid Build Coastguard Worker        dummy_g.output(())
618*523fa7a6SAndroid Build Coastguard Worker        additional_kwargs = {}
619*523fa7a6SAndroid Build Coastguard Worker        if hasattr(exported_program, "verifiers"):
620*523fa7a6SAndroid Build Coastguard Worker            additional_kwargs["verifiers"] = [
621*523fa7a6SAndroid Build Coastguard Worker                load_verifier(v) for v in exported_program.verifiers  # pyre-ignore
622*523fa7a6SAndroid Build Coastguard Worker            ]
623*523fa7a6SAndroid Build Coastguard Worker        elif hasattr(exported_program, "dialect"):
624*523fa7a6SAndroid Build Coastguard Worker            additional_kwargs["verifier"] = load_verifier(
625*523fa7a6SAndroid Build Coastguard Worker                exported_program.dialect  # pyre-ignore
626*523fa7a6SAndroid Build Coastguard Worker            )
627*523fa7a6SAndroid Build Coastguard Worker        exported_program = ep.ExportedProgram(
628*523fa7a6SAndroid Build Coastguard Worker            root=res.graph_module,
629*523fa7a6SAndroid Build Coastguard Worker            graph=dummy_g,
630*523fa7a6SAndroid Build Coastguard Worker            graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]),
631*523fa7a6SAndroid Build Coastguard Worker            state_dict=res.state_dict,  # type: ignore[arg-type]
632*523fa7a6SAndroid Build Coastguard Worker            range_constraints=range_constraints,
633*523fa7a6SAndroid Build Coastguard Worker            module_call_graph=res.module_call_graph,
634*523fa7a6SAndroid Build Coastguard Worker            example_inputs=res.example_inputs,
635*523fa7a6SAndroid Build Coastguard Worker            constants=res.constants,
636*523fa7a6SAndroid Build Coastguard Worker            **additional_kwargs,
637*523fa7a6SAndroid Build Coastguard Worker        )
638*523fa7a6SAndroid Build Coastguard Worker
639*523fa7a6SAndroid Build Coastguard Worker        exported_program.graph_module.graph = res.graph_module.graph
640*523fa7a6SAndroid Build Coastguard Worker        exported_program._graph_signature = res.signature
641*523fa7a6SAndroid Build Coastguard Worker        for node in res.graph_module.graph.nodes:
642*523fa7a6SAndroid Build Coastguard Worker            if node.op == "get_attr":
643*523fa7a6SAndroid Build Coastguard Worker                setattr(
644*523fa7a6SAndroid Build Coastguard Worker                    exported_program.graph_module,
645*523fa7a6SAndroid Build Coastguard Worker                    node.target,
646*523fa7a6SAndroid Build Coastguard Worker                    getattr(res.graph_module, node.target),
647*523fa7a6SAndroid Build Coastguard Worker                )
648*523fa7a6SAndroid Build Coastguard Worker        return upgrader.upgrade(exported_program)
649*523fa7a6SAndroid Build Coastguard Worker
650*523fa7a6SAndroid Build Coastguard Worker
651*523fa7a6SAndroid Build Coastguard Workerdef serialize(
652*523fa7a6SAndroid Build Coastguard Worker    exported_program: ep.ExportedProgram,
653*523fa7a6SAndroid Build Coastguard Worker    opset_version: Optional[Dict[str, int]] = None,
654*523fa7a6SAndroid Build Coastguard Worker) -> export_serialize.SerializedArtifact:
655*523fa7a6SAndroid Build Coastguard Worker    serialized_artifact = ExportedProgramSerializer(opset_version).serialize(
656*523fa7a6SAndroid Build Coastguard Worker        exported_program
657*523fa7a6SAndroid Build Coastguard Worker    )
658*523fa7a6SAndroid Build Coastguard Worker    assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
659*523fa7a6SAndroid Build Coastguard Worker    json_program = json.dumps(
660*523fa7a6SAndroid Build Coastguard Worker        export_serialize._dataclass_to_dict(serialized_artifact.exported_program),
661*523fa7a6SAndroid Build Coastguard Worker        cls=export_serialize.EnumEncoder,
662*523fa7a6SAndroid Build Coastguard Worker    )
663*523fa7a6SAndroid Build Coastguard Worker    json_bytes = json_program.encode("utf-8")
664*523fa7a6SAndroid Build Coastguard Worker    artifact = export_serialize.SerializedArtifact(
665*523fa7a6SAndroid Build Coastguard Worker        json_bytes,
666*523fa7a6SAndroid Build Coastguard Worker        serialized_artifact.state_dict,
667*523fa7a6SAndroid Build Coastguard Worker        serialized_artifact.constants,
668*523fa7a6SAndroid Build Coastguard Worker        serialized_artifact.example_inputs,
669*523fa7a6SAndroid Build Coastguard Worker    )
670*523fa7a6SAndroid Build Coastguard Worker    return artifact
671*523fa7a6SAndroid Build Coastguard Worker
672*523fa7a6SAndroid Build Coastguard Worker
673*523fa7a6SAndroid Build Coastguard Workerdef deserialize(
674*523fa7a6SAndroid Build Coastguard Worker    artifact: export_serialize.SerializedArtifact,
675*523fa7a6SAndroid Build Coastguard Worker    expected_opset_version: Optional[Dict[str, int]] = None,
676*523fa7a6SAndroid Build Coastguard Worker) -> ep.ExportedProgram:
677*523fa7a6SAndroid Build Coastguard Worker    assert isinstance(artifact.exported_program, bytes)
678*523fa7a6SAndroid Build Coastguard Worker    exported_program_str = artifact.exported_program.decode("utf-8")
679*523fa7a6SAndroid Build Coastguard Worker    exported_program_dict = json.loads(exported_program_str)
680*523fa7a6SAndroid Build Coastguard Worker    serialized_exported_program = export_serialize._dict_to_dataclass(
681*523fa7a6SAndroid Build Coastguard Worker        schema.ExportedProgram, exported_program_dict
682*523fa7a6SAndroid Build Coastguard Worker    )
683*523fa7a6SAndroid Build Coastguard Worker    return ExportedProgramDeserializer(expected_opset_version).deserialize(
684*523fa7a6SAndroid Build Coastguard Worker        serialized_exported_program,
685*523fa7a6SAndroid Build Coastguard Worker        artifact.state_dict,
686*523fa7a6SAndroid Build Coastguard Worker        artifact.constants,
687*523fa7a6SAndroid Build Coastguard Worker        artifact.example_inputs,
688*523fa7a6SAndroid Build Coastguard Worker    )
689*523fa7a6SAndroid Build Coastguard Worker
690*523fa7a6SAndroid Build Coastguard Worker
691*523fa7a6SAndroid Build Coastguard Workerdef save(
692*523fa7a6SAndroid Build Coastguard Worker    ep_save: ep.ExportedProgram,
693*523fa7a6SAndroid Build Coastguard Worker    f: Union[str, os.PathLike[str], io.BytesIO],
694*523fa7a6SAndroid Build Coastguard Worker    *,
695*523fa7a6SAndroid Build Coastguard Worker    extra_files: Optional[Dict[str, Any]] = None,
696*523fa7a6SAndroid Build Coastguard Worker    opset_version: Optional[Dict[str, int]] = None,
697*523fa7a6SAndroid Build Coastguard Worker) -> None:
698*523fa7a6SAndroid Build Coastguard Worker    if not isinstance(ep_save, ep.ExportedProgram):
699*523fa7a6SAndroid Build Coastguard Worker        raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")
700*523fa7a6SAndroid Build Coastguard Worker
701*523fa7a6SAndroid Build Coastguard Worker    artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version)
702*523fa7a6SAndroid Build Coastguard Worker
703*523fa7a6SAndroid Build Coastguard Worker    if isinstance(f, (str, os.PathLike)):
704*523fa7a6SAndroid Build Coastguard Worker        f = os.fspath(str(f))
705*523fa7a6SAndroid Build Coastguard Worker
706*523fa7a6SAndroid Build Coastguard Worker    with zipfile.ZipFile(f, "w") as zipf:
707*523fa7a6SAndroid Build Coastguard Worker        # Save every field in the SerializedArtifact to a file.
708*523fa7a6SAndroid Build Coastguard Worker        assert isinstance(artifact.exported_program, bytes)
709*523fa7a6SAndroid Build Coastguard Worker        zipf.writestr("serialized_exported_program.json", artifact.exported_program)
710*523fa7a6SAndroid Build Coastguard Worker        zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
711*523fa7a6SAndroid Build Coastguard Worker        zipf.writestr("serialized_constants.pt", artifact.constants)
712*523fa7a6SAndroid Build Coastguard Worker        zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)
713*523fa7a6SAndroid Build Coastguard Worker
714*523fa7a6SAndroid Build Coastguard Worker        zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
715*523fa7a6SAndroid Build Coastguard Worker
716*523fa7a6SAndroid Build Coastguard Worker        # Add extra files if provided
717*523fa7a6SAndroid Build Coastguard Worker        if extra_files:
718*523fa7a6SAndroid Build Coastguard Worker            for extra_file_name, content in extra_files.items():
719*523fa7a6SAndroid Build Coastguard Worker                encoded_content = content.encode("utf-8")
720*523fa7a6SAndroid Build Coastguard Worker                zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
721*523fa7a6SAndroid Build Coastguard Worker
722*523fa7a6SAndroid Build Coastguard Worker
723*523fa7a6SAndroid Build Coastguard Workerdef load(
724*523fa7a6SAndroid Build Coastguard Worker    f: Union[str, os.PathLike[str], io.BytesIO],
725*523fa7a6SAndroid Build Coastguard Worker    *,
726*523fa7a6SAndroid Build Coastguard Worker    extra_files: Optional[Dict[str, Any]] = None,
727*523fa7a6SAndroid Build Coastguard Worker    expected_opset_version: Optional[Dict[str, int]] = None,
728*523fa7a6SAndroid Build Coastguard Worker) -> ep.ExportedProgram:
729*523fa7a6SAndroid Build Coastguard Worker    if isinstance(f, (str, os.PathLike)):
730*523fa7a6SAndroid Build Coastguard Worker        f = os.fspath(str(f))
731*523fa7a6SAndroid Build Coastguard Worker
732*523fa7a6SAndroid Build Coastguard Worker    extra_files = extra_files or {}
733*523fa7a6SAndroid Build Coastguard Worker
734*523fa7a6SAndroid Build Coastguard Worker    with zipfile.ZipFile(f, "r") as zipf:
735*523fa7a6SAndroid Build Coastguard Worker        # Check the version
736*523fa7a6SAndroid Build Coastguard Worker        version = zipf.read("version").decode().split(".")
737*523fa7a6SAndroid Build Coastguard Worker
738*523fa7a6SAndroid Build Coastguard Worker        assert len(version) == len(SCHEMA_VERSION)
739*523fa7a6SAndroid Build Coastguard Worker        if version[0] != str(SCHEMA_VERSION[0]):
740*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError(
741*523fa7a6SAndroid Build Coastguard Worker                f"Serialized version {version} does not match our current "
742*523fa7a6SAndroid Build Coastguard Worker                f"schema version {SCHEMA_VERSION}."
743*523fa7a6SAndroid Build Coastguard Worker            )
744*523fa7a6SAndroid Build Coastguard Worker
745*523fa7a6SAndroid Build Coastguard Worker        # Load serialized_ep and serialized_state_dict from the zip file
746*523fa7a6SAndroid Build Coastguard Worker
747*523fa7a6SAndroid Build Coastguard Worker        serialized_exported_program: Optional[bytes] = None
748*523fa7a6SAndroid Build Coastguard Worker        serialized_state_dict: Optional[bytes] = None
749*523fa7a6SAndroid Build Coastguard Worker        serialized_constants: Optional[bytes] = None
750*523fa7a6SAndroid Build Coastguard Worker        serialized_example_inputs: Optional[bytes] = None
751*523fa7a6SAndroid Build Coastguard Worker
752*523fa7a6SAndroid Build Coastguard Worker        for file_info in zipf.infolist():
753*523fa7a6SAndroid Build Coastguard Worker            file_content = zipf.read(file_info.filename)
754*523fa7a6SAndroid Build Coastguard Worker
755*523fa7a6SAndroid Build Coastguard Worker            if file_info.filename == "serialized_exported_program.json":
756*523fa7a6SAndroid Build Coastguard Worker                serialized_exported_program = file_content
757*523fa7a6SAndroid Build Coastguard Worker            elif file_info.filename == "serialized_state_dict.json":
758*523fa7a6SAndroid Build Coastguard Worker                print("This version of file is deprecated")
759*523fa7a6SAndroid Build Coastguard Worker                serialized_state_dict = file_content
760*523fa7a6SAndroid Build Coastguard Worker            elif file_info.filename == "serialized_constants.json":
761*523fa7a6SAndroid Build Coastguard Worker                print("This version of file is deprecated")
762*523fa7a6SAndroid Build Coastguard Worker                serialized_constants = file_content
763*523fa7a6SAndroid Build Coastguard Worker            elif file_info.filename == "serialized_state_dict.pt":
764*523fa7a6SAndroid Build Coastguard Worker                serialized_state_dict = file_content
765*523fa7a6SAndroid Build Coastguard Worker            elif file_info.filename == "serialized_constants.pt":
766*523fa7a6SAndroid Build Coastguard Worker                serialized_constants = file_content
767*523fa7a6SAndroid Build Coastguard Worker            elif file_info.filename.startswith("extra_files"):
768*523fa7a6SAndroid Build Coastguard Worker                filename = file_info.filename.split("/", 1)[1]
769*523fa7a6SAndroid Build Coastguard Worker                extra_files[filename] = file_content.decode("utf-8")
770*523fa7a6SAndroid Build Coastguard Worker            elif file_info.filename == "serialized_example_inputs.pt":
771*523fa7a6SAndroid Build Coastguard Worker                serialized_example_inputs = file_content
772*523fa7a6SAndroid Build Coastguard Worker
773*523fa7a6SAndroid Build Coastguard Worker        assert serialized_exported_program is not None
774*523fa7a6SAndroid Build Coastguard Worker        assert serialized_state_dict is not None
775*523fa7a6SAndroid Build Coastguard Worker        assert serialized_constants is not None
776*523fa7a6SAndroid Build Coastguard Worker        assert serialized_example_inputs is not None
777*523fa7a6SAndroid Build Coastguard Worker
778*523fa7a6SAndroid Build Coastguard Worker        artifact: export_serialize.SerializedArtifact = (
779*523fa7a6SAndroid Build Coastguard Worker            export_serialize.SerializedArtifact(
780*523fa7a6SAndroid Build Coastguard Worker                serialized_exported_program,
781*523fa7a6SAndroid Build Coastguard Worker                serialized_state_dict,
782*523fa7a6SAndroid Build Coastguard Worker                serialized_constants,
783*523fa7a6SAndroid Build Coastguard Worker                serialized_example_inputs,
784*523fa7a6SAndroid Build Coastguard Worker            )
785*523fa7a6SAndroid Build Coastguard Worker        )
786*523fa7a6SAndroid Build Coastguard Worker
787*523fa7a6SAndroid Build Coastguard Worker        # Deserialize ExportedProgram
788*523fa7a6SAndroid Build Coastguard Worker        ep = deserialize(artifact, expected_opset_version)
789*523fa7a6SAndroid Build Coastguard Worker
790*523fa7a6SAndroid Build Coastguard Worker        return ep
791