xref: /aosp_15_r20/external/executorch/exir/serde/export_serialize.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-ignore-all-errors
8
9# Copied over from caffe2/torch/_export/serde/serialize.py until dialects
10# are supported in torch export serializer.
11
12import base64
13import copy
14import copyreg
15import dataclasses
16import heapq
17import inspect
18import io
19import json
20import logging
21import math
22import operator
23import re
24import typing
25
26from contextlib import contextmanager
27from dataclasses import dataclass, field
28from enum import Enum
29from typing import (
30    Any,
31    Callable,
32    cast,
33    Dict,
34    final,
35    Iterator,
36    List,
37    Optional,
38    Set,
39    Tuple,
40    Union,
41)
42
43import sympy
44
45import torch
46import torch.export.exported_program
47import torch.export.exported_program as ep
48from torch._export.serde.schema import SchemaVersion
49from torch._export.verifier import load_verifier
50from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
51from torch.fx.experimental import symbolic_shapes
52from torch.utils import _pytree as pytree
53from torch.utils._pytree import treespec_dumps, treespec_loads
54from torch.utils._sympy.numbers import int_oo
55from torch.utils._sympy.value_ranges import ValueRanges
56
57# pyre-ignore
58
59from .schema import (  # type: ignore[attr-defined]
60    Argument,
61    BufferMutationSpec,
62    ConstantInputSpec,
63    ConstantValue,
64    CustomObjArgument,
65    Device,
66    ExportedProgram,
67    GradientToParameterSpec,
68    GradientToUserInputSpec,
69    Graph,
70    GraphArgument,
71    GraphModule,
72    GraphSignature,
73    InputSpec,
74    InputToBufferSpec,
75    InputToCustomObjSpec,
76    InputTokenSpec,
77    InputToParameterSpec,
78    InputToTensorConstantSpec,
79    Layout,
80    LossOutputSpec,
81    MemoryFormat,
82    ModuleCallEntry,
83    ModuleCallSignature,
84    NamedArgument,
85    Node,
86    OptionalTensorArgument,
87    OutputSpec,
88    OutputTokenSpec,
89    RangeConstraint,
90    ScalarType,
91    SCHEMA_VERSION,
92    SymBool,
93    SymBoolArgument,
94    SymExpr,
95    SymExprHint,
96    SymInt,
97    SymIntArgument,
98    TensorArgument,
99    TensorMeta,
100    TokenArgument,
101    TREESPEC_VERSION,
102    UserInputMutationSpec,
103    UserInputSpec,
104    UserOutputSpec,
105)
106from .union import _Union
107
108
# Names exported by a star-import of this module.
__all__ = [
    "serialize",
    "GraphModuleSerializer",
    "ExportedProgramSerializer",
    "GraphModuleDeserializer",
    "ExportedProgramDeserializer",
]
116
117from .upgrade import GraphModuleOpUpgrader
118
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
120
121
class SerializeError(RuntimeError):
    """Error raised by the serializers in this module for values they do not support."""
124
125
126def _reverse_map(d: Dict[Any, Enum]):
127    return {v.value: k for k, v in d.items()}
128
129
# Type alias over fake tensors, concrete/symbolic ints and bools, and custom
# object arguments.
# NOTE(review): presumably the kinds of value stored in node.meta["val"] —
# confirm against the alias's users.
MetaType = Union[
    FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument
]
133
134
# Separator placed between entries when a metadata stack (nn_module_stack,
# source_fn_stack, torch_fn) is flattened into a single string.
ST_DELIMITER = ";"
136
# torch dtype -> schema ScalarType member. A dtype absent from this table
# cannot be looked up by serialize_tensor_meta (KeyError).
_TORCH_TO_SERIALIZE_DTYPE = {
    torch.uint8: ScalarType.BYTE,
    torch.int8: ScalarType.CHAR,
    torch.int16: ScalarType.SHORT,
    torch.int32: ScalarType.INT,
    torch.int64: ScalarType.LONG,
    torch.float16: ScalarType.HALF,
    torch.float32: ScalarType.FLOAT,
    torch.float64: ScalarType.DOUBLE,
    torch.complex32: ScalarType.COMPLEXHALF,
    torch.complex64: ScalarType.COMPLEXFLOAT,
    torch.complex128: ScalarType.COMPLEXDOUBLE,
    torch.bool: ScalarType.BOOL,
    torch.bfloat16: ScalarType.BFLOAT16,
    torch.uint16: ScalarType.UINT16
}


# Inverse table, keyed by the ScalarType member's `.value` (not the member).
_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE)  # type: ignore[arg-type]
156
157
# torch layout -> schema Layout member.
_TORCH_TO_SERIALIZE_LAYOUT = {
    torch.sparse_coo: Layout.SparseCoo,
    torch.sparse_csr: Layout.SparseCsr,
    torch.sparse_csc: Layout.SparseCsc,
    torch.sparse_bsr: Layout.SparseBsr,
    torch.sparse_bsc: Layout.SparseBsc,
    torch._mkldnn: Layout._mkldnn,  # type: ignore[attr-defined]
    torch.strided: Layout.Strided,
}


# Inverse table, keyed by the Layout member's `.value` (not the member).
_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT)  # type: ignore[arg-type]
170
171
# torch memory format -> schema MemoryFormat member.
_TORCH_TO_SERIALIZE_MEMORY_FORMAT = {
    torch.contiguous_format: MemoryFormat.ContiguousFormat,
    torch.channels_last: MemoryFormat.ChannelsLast,
    torch.channels_last_3d: MemoryFormat.ChannelsLast3d,
    torch.preserve_format: MemoryFormat.PreserveFormat,
}


# Inverse table, keyed by the MemoryFormat member's `.value` (not the member).
_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT)  # type: ignore[arg-type]
181
182
# call_function targets that GraphModuleSerializer.handle_call_function
# serializes as a node with a single sym-int output.
_SYM_INT_OPS = {
    operator.mul,
    operator.add,
    operator.sub,
    operator.floordiv,
    operator.mod,
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_sqrt,
}
196
197
# call_function targets that GraphModuleSerializer.handle_call_function
# serializes as a node with a single sym-bool output.
_SYM_BOOL_OPS = {
    operator.eq,
    operator.ne,
    operator.le,
    operator.ge,
    operator.lt,
    operator.gt,
    torch.sym_not,
}
207
208
@dataclass
class SerializedArtifact:
    """Bundle of four serialized payloads, each held as raw bytes."""

    # Serialized exported program.
    exported_program: bytes
    # Serialized state dict.
    state_dict: bytes
    # Serialized constants.
    constants: bytes
    # Serialized example inputs.
    example_inputs: bytes
215
216
@dataclass
class _SerializedProgram:
    """
    Same four fields as SerializedArtifact, except that `exported_program`
    is a schema `ExportedProgram` object rather than bytes.
    """

    # Program in schema (dataclass) form, not yet encoded to bytes.
    exported_program: ExportedProgram
    # Serialized state dict.
    state_dict: bytes
    # Serialized constants.
    constants: bytes
    # Serialized example inputs.
    example_inputs: bytes
223
224
def deserialize_device(d: Device) -> torch.device:
    """
    Turn a serialized `Device` record into a `torch.device`.

    The `index` keyword is forwarded only when the record carries one.
    """
    device_kwargs = {"type": d.type}
    if d.index is not None:
        device_kwargs["index"] = d.index
    return torch.device(**device_kwargs)  # type: ignore[call-overload]
229
230
def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
    """
    Convert an int or `torch.SymInt` into the schema `SymInt` union.

    Concrete values are stored as `as_int`. Symbolic values are stored as an
    expression string, together with the node's hint when it has one.

    Raises:
        SerializeError: if `s` is neither an int nor a `torch.SymInt`.
    """
    if not isinstance(s, (torch.SymInt, int)):
        raise SerializeError(
            f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`"
        )

    if symbolic_shapes.is_concrete_int(s):
        return SymInt.create(as_int=int(s))

    # Only a genuinely symbolic value can reach this point.
    assert isinstance(s, torch.SymInt)
    if s.node.hint is None:
        expr = SymExpr(str(s))
    else:
        expr = SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint))
    return SymInt.create(as_expr=expr)
247
248
def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool:
    """
    Convert a bool or `torch.SymBool` into the schema `SymBool` union.

    Concrete values are stored as `as_bool`; symbolic values are stored as
    the string form of their expression.

    Raises:
        SerializeError: if `s` is neither a bool nor a `torch.SymBool`.
    """
    if not isinstance(s, (torch.SymBool, bool)):
        raise SerializeError(
            f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`"
        )

    if symbolic_shapes.is_concrete_bool(s):
        return SymBool.create(as_bool=bool(s))

    symbolic_expr = SymExpr(expr_str=str(s))
    return SymBool.create(as_expr=symbolic_expr)
259
260
def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
    """
    Build the schema `TensorMeta` record for tensor `t`.

    Dtype and layout are translated through the module's lookup tables, and
    each size and stride goes through `serialize_sym_int`. The storage offset
    is always recorded as 0.
    """
    scalar_type = _TORCH_TO_SERIALIZE_DTYPE[t.dtype]
    sizes = [serialize_sym_int(extent) for extent in t.shape]
    needs_grad = t.requires_grad
    device = Device(type=t.device.type, index=t.device.index)
    strides = [serialize_sym_int(step) for step in t.stride()]
    # TODO needs to be fixed: the tensor's own storage offset is not read.
    offset = serialize_sym_int(0)
    layout = _TORCH_TO_SERIALIZE_LAYOUT[t.layout]

    return TensorMeta(
        dtype=scalar_type,
        sizes=sizes,
        requires_grad=needs_grad,
        device=device,
        strides=strides,
        storage_offset=offset,
        layout=layout,
    )
274
275
# Stack of deserializers; _reconstruct_fake_tensor asserts it is non-empty and
# uses the last entry to rebuild fake tensors.
_CURRENT_DESERIALIZER: List["GraphModuleDeserializer"] = []
277
278
def _reduce_fake_tensor(fake_tensor: FakeTensor):
    """
    Pickle reducer for `FakeTensor`.

    Returns the `(callable, args)` pair expected from a reducer: the tensor is
    represented only by its JSON-encoded `TensorMeta` plus a flag recording
    whether it was an `nn.Parameter`.
    """
    was_parameter = isinstance(fake_tensor, torch.nn.Parameter)
    meta_as_dict = _dataclass_to_dict(serialize_tensor_meta(fake_tensor))
    encoded_meta = json.dumps(meta_as_dict, cls=EnumEncoder).encode("utf-8")
    return _reconstruct_fake_tensor, (encoded_meta, was_parameter)
286
287
def _reconstruct_fake_tensor(
    serialized_tensor_meta: bytes, is_parameter: bool
) -> FakeTensor:
    """
    Inverse of `_reduce_fake_tensor`.

    Decodes the JSON `TensorMeta` and asks the deserializer on top of
    `_CURRENT_DESERIALIZER` to materialize it; the result is wrapped in an
    `nn.Parameter` when `is_parameter` is set.
    """
    meta_fields = json.loads(serialized_tensor_meta.decode("utf-8"))
    tensor_meta = _dict_to_dataclass(TensorMeta, meta_fields)

    # The active deserializer is what actually creates the fake tensor.
    assert len(_CURRENT_DESERIALIZER) != 0, "Need access to current deserializer state"
    active_deserializer = _CURRENT_DESERIALIZER[-1]
    rebuilt = active_deserializer.deserialize_tensor_meta(tensor_meta)

    if not is_parameter:
        return rebuilt
    return torch.nn.Parameter(rebuilt)  # type: ignore[return-value]
300
301
def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes:
    """
    Serialize `artifact` with `torch.save` and return the raw bytes.

    While saving, a reducer for `FakeTensor` is temporarily registered in
    `copyreg.dispatch_table`; it is removed again even if saving fails.
    Asserts that no `FakeTensor` reducer is registered beforehand.
    """
    assert (
        FakeTensor not in copyreg.dispatch_table
    ), "Refusing to stomp on existing FakeTensor reducer"
    try:
        copyreg.pickle(FakeTensor, _reduce_fake_tensor)
        # NOTE: the artifact is saved as given; no device transfer happens in
        # this function.
        # TODO: device placement should be handled by deserialization instead.
        with io.BytesIO() as sink:
            torch.save(artifact, sink)
            return sink.getvalue()
    finally:
        copyreg.dispatch_table.pop(FakeTensor)
319
320
def deserialize_torch_artifact(
    serialized: Union[Dict[str, Any], Tuple[Any, ...], bytes]
):
    """
    Load an artifact produced by `serialize_torch_artifact`.

    A dict or tuple is treated as already deserialized and returned as-is.
    An empty payload yields an empty dict. Anything else is handed to
    `torch.load`, and the result must be a tuple or dict.
    """
    already_loaded = isinstance(serialized, (dict, tuple))
    if already_loaded:
        return serialized
    if len(serialized) == 0:
        return {}

    # NOTE(security): torch.load is pickle-based; only feed it trusted bytes.
    artifact = torch.load(io.BytesIO(serialized))
    assert isinstance(artifact, (tuple, dict))
    return artifact
333
334
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
    """
    Convert a sympy bound into a plain Python number.

    Infinities map to `math.inf` / `-math.inf` and `sympy.Integer` maps to
    `int`. Any other expression is logged with a warning and then rounded
    according to `adjust`, which must be "floor" or "ceil".

    Raises:
        RuntimeError: if a non-integer value is given an unknown `adjust`.
    """
    positive_infinities = (sympy.oo, int_oo)
    if val in positive_infinities:
        return math.inf
    negative_infinities = (-sympy.oo, -int_oo)
    if val in negative_infinities:
        return -math.inf
    if isinstance(val, sympy.Integer):
        return int(val)

    # TODO: Remove this adjustment when Ed gets rid of fractional ranges
    log.warning(
        "Export constraints cannot be non-integer expressions. Found "
        "type %s, and value %s. We will attempt to %s "
        "this value.",
        type(val),
        val,
        adjust,
    )

    if adjust == "floor":
        rounder = math.floor
    elif adjust == "ceil":
        rounder = math.ceil
    else:
        raise RuntimeError(f"Got invalid adjustment {adjust}")
    return rounder(val)
360
361
def _int_to_sympy_int(val) -> sympy.Expr:
    """
    Convert a plain Python number into its sympy counterpart.

    `math.inf` and `-math.inf` become `int_oo` and `-int_oo`; every other
    value is wrapped in `sympy.Integer`.
    """
    if val == math.inf:
        converted = int_oo
    elif val == -math.inf:
        converted = -int_oo
    else:
        converted = sympy.Integer(val)
    return converted
369
370
def serialize_range_constraints(
    range_constraints: Dict[sympy.Symbol, ValueRanges]
) -> Dict[str, RangeConstraint]:
    """
    Convert symbol -> ValueRanges into symbol-name -> RangeConstraint.

    Non-integer bounds are tightened inward: the lower bound is rounded with
    "ceil" and the upper bound with "floor".
    """
    serialized: Dict[str, RangeConstraint] = {}
    for symbol, bounds in range_constraints.items():
        symbol_name = str(symbol)
        lower = _sympy_int_to_int(bounds.lower, "ceil")  # type: ignore[arg-type]
        upper = _sympy_int_to_int(bounds.upper, "floor")  # type: ignore[arg-type]
        serialized[symbol_name] = RangeConstraint(lower, upper)
    return serialized
381
382
383def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
384    returns = target._schema.returns
385    return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)
386
387
388def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool:
389    returns = target._schema.returns
390    if len(returns) != 1:
391        return False
392    return_type = returns[0].real_type
393    return isinstance(return_type, torch.ListType) and isinstance(
394        return_type.getElementType(), torch.TensorType
395    )
396
397
398def _output_node_at_index(node, index):
399    for user in node.users:
400        assert user.target is operator.getitem, f"{user} is not a getitem node"
401        if index == user.args[1]:
402            return user
403    return None
404
405
@dataclass
class GraphState:
    """
    Mutable accumulator for the pieces of one graph while it is serialized.

    GraphModuleSerializer keeps one instance in `self.graph_state` and
    temporarily swaps in a fresh one (see `save_graph_state`) while it
    serializes a nested graph.
    """

    # Serialized graph inputs; appended to by handle_placeholder.
    inputs: List[Argument] = field(default_factory=list)
    # Serialized graph outputs; assigned by handle_output.
    outputs: List[Argument] = field(default_factory=list)
    # Serialized nodes; appended to by handle_call_function.
    nodes: List[Node] = field(default_factory=list)
    # Tensor metadata keyed by node name.
    tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
    # Symbolic-int values keyed by node name; consulted by is_sym_int_arg.
    sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
    # Symbolic-bool values keyed by node name; consulted by is_sym_bool_arg.
    sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
    # Set to True by handle_output when the output arg is a single fx.Node
    # rather than a tuple/list.
    is_single_tensor_return: bool = False
    # Custom-object metadata keyed by node name.
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
416
417
class Final(type):
    """Metaclass that forbids subclassing any class created with it."""

    def __new__(metacls, name, bases, classdict):
        # Reject the first base that was itself created by this metaclass.
        sealed_base = next((base for base in bases if isinstance(base, Final)), None)
        if sealed_base is not None:
            raise TypeError(
                f"type '{sealed_base.__name__}' is not an acceptable base type"
            )
        namespace = dict(classdict)
        return type.__new__(metacls, name, bases, namespace)
424
425
426class GraphModuleSerializer:
427    def __init__(
428        self,
429        graph_signature: ep.ExportGraphSignature,
430        module_call_graph: List[ep.ModuleCallEntry],
431    ):
432        self.graph_state = GraphState()
433        self.graph_signature = graph_signature
434        self.module_call_graph = module_call_graph
435        self.custom_objs: Dict[str, torch._C.ScriptObject] = {}
436
437    @contextmanager
438    def save_graph_state(self):
439        saved = self.graph_state
440        self.graph_state = GraphState()
441        try:
442            yield
443        finally:
444            self.graph_state = saved
445
446    def handle_placeholder(self, node: torch.fx.Node):
447        assert node.op == "placeholder"
448        if isinstance(node.meta["val"], torch.Tensor):
449            graph_input = Argument.create(as_tensor=TensorArgument(name=node.name))
450            self.graph_state.tensor_values[node.name] = serialize_tensor_meta(
451                node.meta["val"]
452            )
453        elif isinstance(node.meta["val"], torch.SymInt):
454            graph_input = Argument.create(
455                as_sym_int=SymIntArgument.create(as_name=node.name)
456            )
457            self.graph_state.sym_int_values[node.name] = serialize_sym_int(
458                node.meta["val"]
459            )
460        elif isinstance(node.meta["val"], (int, bool, str, float, type(None))):
461            graph_input = self.serialize_input(node.meta["val"])
462        elif isinstance(node.meta["val"], ep.CustomObjArgument):
463            class_fqn = node.meta["val"].class_fqn
464            graph_input = Argument.create(
465                as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)
466            )
467            self.graph_state.custom_obj_values[node.name] = (
468                self.serialize_script_obj_meta(node.meta["val"])
469            )
470        else:
471            raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}")
472        self.graph_state.inputs.append(graph_input)
473
474    def handle_output(self, node: torch.fx.Node):
475        assert node.op == "output"
476        assert len(node.args) == 1, "FX.Node's args should have one arg"
477        node_args = node.args[0]
478        if isinstance(node_args, torch.fx.Node):
479            # For singleton tensor returns
480            self.graph_state.is_single_tensor_return = True
481            self.graph_state.outputs = [self.serialize_input(node_args)]
482        else:
483            assert isinstance(node_args, (tuple, list))
484            self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args]
485
486    def serialize_operator(self, target) -> str:
487        if isinstance(target, str):
488            return target
489        elif target.__module__.startswith("torch._ops"):
490            # TODO(zhxchen17) Maybe provide a function name helper in FX.
491            # From torch.fx.node._get_qualified_name
492            module = target.__module__.replace("torch._ops", "torch.ops")
493            return f"{module}.{target.__name__}"
494        else:  # TODO(zhxchen17) Don't catch all here.
495            return f"{target.__module__}.{target.__name__}"
496
497    def handle_call_function(self, node: torch.fx.Node):
498        assert node.op == "call_function"
499
500        # getitem has been handled in the producer node, skip it here
501        if node.target is operator.getitem:
502            return
503
504        if node.target in _SYM_INT_OPS:
505            assert len(node.kwargs) == 0
506            meta_val = node.meta["val"]
507            ex_node = Node(
508                target=self.serialize_operator(node.target),
509                inputs=self.serialize_sym_op_inputs(node.target, node.args),
510                outputs=[
511                    Argument.create(
512                        as_sym_int=self.serialize_sym_int_output(node.name, meta_val)
513                    )
514                ],
515                metadata=self.serialize_metadata(node),
516            )
517        elif node.target in _SYM_BOOL_OPS:
518            assert len(node.kwargs) == 0
519            meta_val = node.meta["val"]
520            ex_node = Node(
521                target=self.serialize_operator(node.target),
522                inputs=self.serialize_sym_op_inputs(node.target, node.args),
523                outputs=[
524                    Argument.create(
525                        as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val)
526                    )
527                ],
528                metadata=self.serialize_metadata(node),
529            )
530        elif isinstance(node.target, torch._ops.OpOverload):
531            ex_node = Node(
532                target=self.serialize_operator(node.target),
533                inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
534                outputs=self.serialize_outputs(node),
535                # TODO: create a new tensor_values here, meta might have faketensor info
536                metadata=self.serialize_metadata(node),
537            )
538        elif isinstance(node.target, torch._ops.HigherOrderOperator):
539            ex_node = Node(
540                target=self.serialize_operator(node.target),
541                inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
542                outputs=self.serialize_hoo_outputs(node),
543                metadata=self.serialize_metadata(node),
544            )
545        else:
546            raise SerializeError(f"Serializing {node.target} is not supported")
547
548        self.graph_state.nodes.append(ex_node)
549
550    def handle_get_attr(self, node):
551        pass
552
553    def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
554        ret = {}
555        if stack_trace := node.meta.get("stack_trace"):
556            ret["stack_trace"] = stack_trace
557
558        if nn_module_stack := node.meta.get("nn_module_stack"):
559
560            def export_nn_module_stack(val):
561                assert isinstance(val, tuple) and len(val) == 2
562                path, ty = val
563
564                assert isinstance(path, str)
565
566                # node.meta["nn_module_stack"] could have two forms:
567                # 1. (path: str, module_type: 'type'), e.g.
568                #    ('', <class 'sigmoid.inference.MySimpleModel'>)
569                # 2. (path: str, module_type: str), e.g.
570                #    ('', 'sigmoid.inference.MySimpleModel')
571                # ExportedProgram directly produced by torch.export() has form 1
572                # ExportedProgram deserialized from disk has form 2
573                # TODO: This is not ideal, we should fix this.
574                if isinstance(ty, str):
575                    normalized_ty = ty
576                else:
577                    normalized_ty = ty.__module__ + "." + ty.__qualname__
578
579                return path + "," + normalized_ty
580
581            # Serialize to "key,orig_path,type_str"
582            nn_module_list = [
583                f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items()
584            ]
585            ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list)
586
587        if source_fn_st := node.meta.get("source_fn_stack"):
588            source_fn_list = [
589                f"{source_fn[0]},{self.serialize_operator(source_fn[1])}"
590                for source_fn in source_fn_st
591            ]
592            ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list)
593
594        if torch_fn := node.meta.get("torch_fn"):
595            ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn))
596
597        return ret
598
599    def serialize_script_obj_meta(
600        self, script_obj_meta: ep.CustomObjArgument
601    ) -> CustomObjArgument:
602        return CustomObjArgument(
603            name=script_obj_meta.name,
604            class_fqn=script_obj_meta.class_fqn,
605        )
606
607    def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
608        serialized_args = []
609        args_names = inspect.signature(op).parameters.keys()
610        for args_name, arg in zip(args_names, args):
611            serialized_args.append(
612                NamedArgument(name=args_name, arg=self.serialize_input(arg))
613            )
614        return serialized_args
615
616    def serialize_inputs(
617        self, target: torch._ops.OpOverload, args, kwargs=None
618    ) -> List[NamedArgument]:
619        assert isinstance(target, torch._ops.OpOverload)
620        kwargs = kwargs or {}
621        serialized_args = []
622        for i, schema_arg in enumerate(target._schema.arguments):
623            if schema_arg.name in kwargs:
624                serialized_args.append(
625                    NamedArgument(
626                        name=schema_arg.name,
627                        arg=self.serialize_input(
628                            kwargs[schema_arg.name], schema_arg.type
629                        ),
630                    )
631                )
632            elif not schema_arg.kwarg_only and i < len(args):
633                serialized_args.append(
634                    NamedArgument(
635                        name=schema_arg.name,
636                        arg=self.serialize_input(args[i], schema_arg.type),
637                    )
638                )
639            else:
640                # We intentionally don't serialize the missing arguments
641                # with default values
642                pass
643
644        return serialized_args
645
646    def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]:
647        """
648        For serializing HOO inputs since HOOs do not have a schema.
649        """
650        inputs = [
651            NamedArgument(
652                name="",
653                arg=self.serialize_input(a),
654            )
655            for a in args
656        ]
657        inputs.extend(
658            [
659                NamedArgument(name=name, arg=self.serialize_input(a))
660                for name, a in kwargs.items()
661            ]
662        )
663        return inputs
664
665    def is_sym_int_arg(self, arg) -> bool:
666        return isinstance(arg, int) or (
667            isinstance(arg, torch.fx.Node)
668            and arg.name in self.graph_state.sym_int_values
669        )
670
671    def is_sym_bool_arg(self, arg) -> bool:
672        return isinstance(arg, bool) or (
673            isinstance(arg, torch.fx.Node)
674            and arg.name in self.graph_state.sym_bool_values
675        )
676
677    def serialize_input(
678        self, arg, arg_type: Optional[torch._C.Argument] = None
679    ) -> Argument:
680        import torch._inductor.ir as inductor_ir
681
682        inductor_tensor_buffers = (
683            inductor_ir.Buffer,
684            inductor_ir.ReinterpretView,
685        )
686
687        if isinstance(arg, torch.fx.Node):
688            if arg.op == "get_attr":
689                assert isinstance(arg.target, str)
690                attr = getattr(arg.graph.owning_module, arg.target)
691
692                if isinstance(attr, torch.Tensor):
693                    raise SerializeError(
694                        "getattr nodes containing tensors should not appear in the graph"
695                    )
696                elif isinstance(attr, torch.fx.GraphModule):
697                    with self.save_graph_state():
698                        graph = self.serialize_graph(attr)
699                    return Argument.create(
700                        as_graph=GraphArgument(name=arg.target, graph=graph)
701                    )
702                else:
703                    raise SerializeError(
704                        f"Unsupported getattr attribute {arg.target} with type: {type(attr)}"
705                    )
706            elif self.is_sym_int_arg(arg):
707                return Argument.create(
708                    as_sym_int=SymIntArgument.create(as_name=arg.name)
709                )
710            elif self.is_sym_bool_arg(arg):
711                return Argument.create(
712                    as_sym_bool=SymBoolArgument.create(as_name=arg.name)
713                )
714            else:
715                if isinstance(arg.meta["val"], ep.CustomObjArgument):
716                    return Argument.create(
717                        as_custom_obj=CustomObjArgument(
718                            name=arg.name, class_fqn=arg.meta["val"].class_fqn
719                        )
720                    )
721                return Argument.create(as_tensor=TensorArgument(name=arg.name))
722        elif isinstance(arg, inductor_tensor_buffers):
723            # Other branches are for arguments in fx node.
724            # This is a special branch for handling buffers (representing tensor arguments)
725            # for inductor's ExternalFallbackNode
726            # export_extern_kernel_node() is using this function to serialize arguments
727            arg_name = arg.get_name()
728            assert arg_name is not None, "Buffer must have valid name"
729            return Argument.create(as_tensor=TensorArgument(name=arg_name))
730        elif isinstance(arg, torch.SymInt):
731            # This is a special branch for handling SymInt args in inductor's
732            # ExternalFallbackNode.
733            # For regular FX graph, SymInt arg should be a fx.Node with
734            # self.is_sym_int_arg(arg) being true
735            return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
736        elif isinstance(arg, bool):
737            return Argument.create(as_bool=arg)
738        elif isinstance(arg, str):
739            return Argument.create(as_string=arg)
740        elif isinstance(arg, int):
741            return Argument.create(as_int=arg)
742        elif isinstance(arg, float):
743            return Argument.create(as_float=arg)
744        elif arg is None:
745            return Argument.create(as_none=())
746        elif isinstance(arg, (list, tuple)):
747            if len(arg) == 0:
748                if arg_type is not None:
749                    if isinstance(arg_type, torch.OptionalType):
750                        arg_type = arg_type.getElementType()  # type: ignore[assignment]
751                    assert isinstance(arg_type, torch.ListType)
752                    elem_type = arg_type.getElementType()
753                    if isinstance(elem_type, torch.OptionalType):
754                        elem_type = elem_type.getElementType()
755
756                    if isinstance(elem_type, torch.BoolType):
757                        return Argument.create(as_bools=[])
758                    elif isinstance(elem_type, torch.IntType):
759                        return Argument.create(as_ints=[])
760                    elif isinstance(elem_type, torch.FloatType):
761                        return Argument.create(as_floats=[])
762                    elif isinstance(elem_type, torch.StringType):
763                        return Argument.create(as_strings=[])
764                    elif isinstance(elem_type, torch.TensorType):
765                        return Argument.create(as_tensors=[])
766                    else:
767                        # I believe empty symint lists default to ints, but
768                        # please file an issue if this is not the case
769                        raise SerializeError(f"Empty list with type {elem_type} nyi.")
770                else:
771                    # We could serialize this by default to a tensor list. This
772                    # is needed in the HOO case
773                    log.warning(
774                        "Unsure how to serialize the given empty list, "
775                        "as we don't know what is the type of this argument. "
776                        "Serializing it as a tensor list by default."
777                    )
778                    return Argument.create(as_tensors=[])
779
780            # Must check bool first, as bool is also treated as int
781            if all(isinstance(a, bool) for a in arg):
782                return Argument.create(as_bools=list(arg))
783            elif all(isinstance(a, int) for a in arg):
784                return Argument.create(as_ints=list(arg))
785            elif all(isinstance(a, float) for a in arg):
786                return Argument.create(as_floats=list(arg))
787            elif all(isinstance(a, str) for a in arg):
788                return Argument.create(as_strings=list(arg))
789            elif all(isinstance(a, torch.SymInt) for a in arg):
790                # This is a special branch for handling SymInt args in inductor's
791                # ExternalFallbackNode.
792                # For regular FX graph, SymInt arg should be a fx.Node with
793                # self.is_sym_int_arg(arg) being true
794                return Argument.create(
795                    as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg]
796                )
797            elif all(self.is_sym_int_arg(a) for a in arg):
798                # list of sym_ints
799                values = []
800                for a in arg:
801                    if isinstance(a, torch.fx.Node):
802                        values.append(SymIntArgument.create(as_name=a.name))
803                    elif isinstance(a, int):
804                        values.append(SymIntArgument.create(as_int=a))
805                return Argument.create(as_sym_ints=values)
806            elif all(self.is_sym_bool_arg(a) for a in arg):
807                # list of sym_bools
808                values = []
809                for a in arg:
810                    if isinstance(a, torch.fx.Node):
811                        values.append(SymBoolArgument.create(as_name=a.name))
812                    elif isinstance(a, bool):
813                        values.append(SymBoolArgument.create(as_bool=a))
814                return Argument.create(as_sym_bools=values)
815            elif all(isinstance(a, torch.fx.Node) for a in arg):
816                # list of tensors
817                arguments = []
818                for a in arg:
819                    if a.op == "get_attr":
820                        raise SerializeError(
821                            "getattr nodes containing tensors should not appear in the graph"
822                        )
823                    arguments.append(TensorArgument(name=a.name))
824                return Argument.create(as_tensors=arguments)
825            elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg):
826                # list of optional tensors
827                def serialize_optional_tensor_args(a):
828                    if a is None:
829                        return OptionalTensorArgument.create(as_none=())
830                    elif isinstance(a, torch.fx.Node):
831                        return OptionalTensorArgument.create(
832                            as_tensor=TensorArgument(name=a.name)
833                        )
834                    else:
835                        raise SerializeError(f"Unsupported list/tuple argument: {a}")
836
837                return Argument.create(
838                    as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
839                )
840            elif all(isinstance(a, inductor_tensor_buffers) for a in arg):
841                # list of inductor buffers
842                return Argument.create(
843                    as_tensors=[TensorArgument(name=a.get_name()) for a in arg],
844                )
845            elif all(
846                isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg
847            ):
848                # list of inductor buffers as optional tensors
849                def serialize_optional_tensor_args(a):
850                    if a is None:
851                        return OptionalTensorArgument.create(as_none=())
852                    elif isinstance(a, inductor_tensor_buffers):
853                        return OptionalTensorArgument.create(
854                            as_tensor=TensorArgument(name=a.get_name())
855                        )
856                    else:
857                        raise SerializeError(f"Unsupported list/tuple argument: {a}")
858
859                return Argument.create(
860                    as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
861                )
862            else:
863                raise SerializeError(
864                    f"Unsupported list/tuple argument type: {[type(a) for a in arg]}"
865                )
866        elif isinstance(arg, torch.dtype):
867            return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg])
868        elif isinstance(arg, torch.device):
869            return Argument.create(as_device=Device(type=arg.type, index=arg.index))
870        elif isinstance(arg, torch.memory_format):
871            return Argument.create(
872                as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg]
873            )
874        elif isinstance(arg, torch.layout):
875            return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg])
876        elif isinstance(arg, torch._C.ScriptObject):
877            if not (
878                arg._has_method("__getstate__")  # type: ignore[attr-defined]
879                and arg._has_method("__setstate__")  # type: ignore[attr-defined]
880            ):
881                raise SerializeError(
882                    f"Unable to serialize custom class {arg}. Please define "
883                    "serialization methods via def_pickle()."
884                )
885            # Custom objects through torchind are serializable with pickle,
886            # through implementing the .def_pickle function.  This should result
887            # in the object containing a __getstate__ and __setstate__
888            # serialize/deserialize function.
889            custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
890            self.custom_objs[custom_obj_name] = arg
891            class_fqn = arg._type().qualified_name()  # type: ignore[attr-defined]
892            return Argument.create(
893                as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)
894            )
895        elif isinstance(arg, torch._ops.OpOverload):
896            return Argument.create(as_operator=self.serialize_operator(arg))
897        else:
898            raise SerializeError(f"Unsupported argument type: {type(arg)}")
899
900    def serialize_tensor_output(self, name, meta_val) -> TensorArgument:
901        assert name not in self.graph_state.tensor_values
902        self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val)
903        return TensorArgument(name=name)
904
905    def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument:
906        assert name not in self.graph_state.sym_int_values
907        self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val)
908        return SymIntArgument.create(as_name=name)
909
910    def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument:
911        assert name not in self.graph_state.sym_bool_values
912        self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
913        return SymBoolArgument.create(as_name=name)
914
915    def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
916        if spec.kind == ep.InputKind.USER_INPUT:
917            if isinstance(spec.arg, ep.ConstantArgument):
918                if isinstance(spec.arg.value, int):
919                    constant_spec = ConstantValue.create(as_int=spec.arg.value)
920                elif isinstance(spec.arg.value, bool):
921                    constant_spec = ConstantValue.create(as_bool=spec.arg.value)
922                elif isinstance(spec.arg.value, str):
923                    constant_spec = ConstantValue.create(as_string=spec.arg.value)
924                elif isinstance(spec.arg.value, float):
925                    constant_spec = ConstantValue.create(as_float=spec.arg.value)
926                elif spec.arg.value is None:
927                    constant_spec = ConstantValue.create(as_none=())
928                else:
929                    raise SerializeError(
930                        f"Unhandled constant input {spec.arg.value} to serialize"
931                    )
932                return InputSpec.create(
933                    constant_input=ConstantInputSpec(
934                        name=spec.arg.name, value=constant_spec
935                    )
936                )
937            else:
938                return InputSpec.create(
939                    user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg))
940                )
941        elif spec.kind == ep.InputKind.PARAMETER:
942            assert spec.target is not None
943            assert isinstance(spec.arg, ep.TensorArgument)
944            return InputSpec.create(
945                parameter=InputToParameterSpec(
946                    arg=TensorArgument(name=spec.arg.name),
947                    parameter_name=spec.target,
948                )
949            )
950        elif spec.kind == ep.InputKind.BUFFER:
951            assert spec.target is not None
952            assert isinstance(spec.arg, ep.TensorArgument)
953            assert spec.persistent is not None
954            return InputSpec.create(
955                buffer=InputToBufferSpec(
956                    arg=TensorArgument(name=spec.arg.name),
957                    buffer_name=spec.target,
958                    persistent=spec.persistent,
959                )
960            )
961        elif spec.kind == ep.InputKind.CONSTANT_TENSOR:
962            assert spec.target is not None
963            assert isinstance(spec.arg, ep.TensorArgument)
964            return InputSpec.create(
965                tensor_constant=InputToTensorConstantSpec(
966                    arg=TensorArgument(name=spec.arg.name),
967                    tensor_constant_name=spec.target,
968                )
969            )
970        elif spec.kind == ep.InputKind.CUSTOM_OBJ:
971            assert spec.target is not None
972            assert isinstance(spec.arg, ep.CustomObjArgument)
973            return InputSpec.create(
974                custom_obj=InputToCustomObjSpec(
975                    arg=CustomObjArgument(
976                        name=spec.arg.name, class_fqn=spec.arg.class_fqn
977                    ),
978                    custom_obj_name=spec.target,
979                )
980            )
981        elif spec.kind == ep.InputKind.TOKEN:
982            assert isinstance(spec.arg, ep.TokenArgument)
983            return InputSpec.create(
984                token=InputTokenSpec(
985                    arg=TokenArgument(name=spec.arg.name),
986                )
987            )
988        else:
989            raise AssertionError(f"Unknown argument kind: {spec}")
990
991    def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
992        if spec.kind == ep.OutputKind.USER_OUTPUT:
993            return OutputSpec.create(
994                user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg))
995            )
996        elif spec.kind == ep.OutputKind.LOSS_OUTPUT:
997            assert isinstance(spec.arg, ep.TensorArgument)
998            return OutputSpec.create(
999                loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name))
1000            )
1001        elif spec.kind == ep.OutputKind.BUFFER_MUTATION:
1002            assert spec.target is not None
1003            assert isinstance(spec.arg, ep.TensorArgument)
1004            return OutputSpec.create(
1005                buffer_mutation=BufferMutationSpec(
1006                    arg=TensorArgument(name=spec.arg.name),
1007                    buffer_name=spec.target,
1008                )
1009            )
1010        elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
1011            assert spec.target is not None
1012            assert isinstance(spec.arg, ep.TensorArgument)
1013            return OutputSpec.create(
1014                gradient_to_parameter=GradientToParameterSpec(
1015                    arg=TensorArgument(name=spec.arg.name),
1016                    parameter_name=spec.target,
1017                )
1018            )
1019        elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT:
1020            assert spec.target is not None
1021            assert isinstance(spec.arg, ep.TensorArgument)
1022            return OutputSpec.create(
1023                gradient_to_user_input=GradientToUserInputSpec(
1024                    arg=TensorArgument(name=spec.arg.name),
1025                    user_input_name=spec.target,
1026                )
1027            )
1028        elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION:
1029            assert spec.target is not None
1030            assert isinstance(spec.arg, ep.TensorArgument)
1031            return OutputSpec.create(
1032                user_input_mutation=UserInputMutationSpec(
1033                    arg=TensorArgument(name=spec.arg.name),
1034                    user_input_name=spec.target,
1035                )
1036            )
1037        elif spec.kind == ep.OutputKind.TOKEN:
1038            assert isinstance(spec.arg, ep.TokenArgument)
1039            return OutputSpec.create(
1040                token=OutputTokenSpec(
1041                    arg=TokenArgument(name=spec.arg.name),
1042                )
1043            )
1044        else:
1045            raise AssertionError(f"Unknown argument kind: {spec}")
1046
1047    def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature:
1048        return GraphSignature(
1049            input_specs=[self.serialize_input_spec(s) for s in sig.input_specs],
1050            output_specs=[self.serialize_output_spec(s) for s in sig.output_specs],
1051        )
1052
1053    def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument:
1054        if isinstance(x, ep.TensorArgument):
1055            return Argument.create(as_tensor=TensorArgument(name=x.name))
1056        elif isinstance(x, ep.SymIntArgument):
1057            return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
1058        elif isinstance(x, ep.ConstantArgument):
1059            return self.serialize_input(x.value)
1060        elif isinstance(x, ep.CustomObjArgument):
1061            return Argument.create(
1062                as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn)
1063            )
1064        else:
1065            raise AssertionError("TODO")
1066
1067    def serialize_module_call_signature(
1068        self, module_call_signature: ep.ModuleCallSignature
1069    ) -> ModuleCallSignature:
1070        return ModuleCallSignature(
1071            inputs=[
1072                self.serialize_argument_spec(x) for x in module_call_signature.inputs
1073            ],
1074            outputs=[
1075                self.serialize_argument_spec(x) for x in module_call_signature.outputs
1076            ],
1077            in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
1078            out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
1079        )
1080
1081    def serialize_module_call_graph(
1082        self, module_call_graph: List[ep.ModuleCallEntry]
1083    ) -> List[ModuleCallEntry]:
1084        return [
1085            ModuleCallEntry(
1086                fqn=entry.fqn,
1087                signature=(
1088                    self.serialize_module_call_signature(entry.signature)
1089                    if entry.signature
1090                    else None
1091                ),
1092            )
1093            for entry in module_call_graph
1094        ]
1095
1096    def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
1097        """For a given node, return the dataclass representing its output values.
1098
1099        [NOTE: Multiple outputs] We handle aggregates differently than FX. For
1100        FX, it looks like:
1101
1102            x = call_function("multiple_return", ...)
1103            element0 = call_function(getitem, x, 0)
1104            foo = call_function("use_output", element0)
1105
1106        We do not want the intermediate `getitem` call, so our serialized thing looks like:
1107
1108            element0, element1, element2 = call_function("multiple_return", ...)
1109            foo = call_function("use_output", element0)
1110
1111        We want names to be consistent across these two schemes, so that we can
1112        mostly reuse the names coming from FX. This function computes a mapping from
1113        the FX representation to our representation, preserving the names.
1114        """
1115        assert node.op == "call_function" and isinstance(
1116            node.target, torch._ops.OpOverload
1117        )
1118
1119        assert isinstance(node.target, torch._ops.OpOverload)
1120        returns = node.target._schema.returns
1121
1122        if len(returns) == 0:
1123            return []
1124
1125        meta_val = node.meta["val"]
1126
1127        # Check single value return
1128        if _is_single_tensor_list_return(node.target):
1129            # e.g "-> Tensor[]"
1130            tensor_args = []
1131            for idx, meta in enumerate(meta_val):
1132                user_node = _output_node_at_index(node, idx)
1133                name = (
1134                    user_node.name
1135                    if user_node is not None
1136                    else f"{node.name}_unused_{idx}"
1137                )
1138                tensor_args.append(self.serialize_tensor_output(name, meta))
1139            return [Argument.create(as_tensors=tensor_args)]
1140        elif len(returns) == 1:
1141            return [self.serialize_output(node.name, meta_val)]
1142
1143        # There are a two possibilities at this point:
1144        # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)"
1145        # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])"
1146        #
1147        # Either way, start by gathering a list of TensorArguments with the correct names.
1148        # For consistent naming with FX, consult the downstream `getitem` node and
1149        # make sure our outputs have the same name.
1150
1151        output_arguments = []
1152        for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)):
1153            if meta is None:
1154                assert isinstance(
1155                    return_schema.real_type, (torch.OptionalType, torch.TensorType)
1156                )
1157                # When the return type is annoated as Tensor type, the op can also return an
1158                # undefined Tensor which will be implicitly converted to None in Python.
1159                output_arguments.append(Argument.create(as_none=()))
1160            elif isinstance(meta, FakeTensor):
1161                assert isinstance(
1162                    return_schema.real_type, (torch.OptionalType, torch.TensorType)
1163                )
1164                user_node = _output_node_at_index(node, idx)
1165                name = (
1166                    user_node.name
1167                    if user_node is not None
1168                    else f"{node.name}_unused_{idx}"
1169                )
1170                output_arguments.append(self.serialize_output(name, meta))
1171            elif isinstance(meta, list):
1172                # for List[Tensor] return type
1173                assert isinstance(
1174                    return_schema.real_type, torch.ListType
1175                ) and isinstance(
1176                    return_schema.real_type.getElementType(), torch.TensorType
1177                )
1178                user_node = _output_node_at_index(node, idx)
1179                assert user_node is not None
1180
1181                args = []
1182                for i, m in enumerate(meta):
1183                    if m is None:
1184                        continue
1185                    sub_user_node = _output_node_at_index(user_node, i)
1186                    assert sub_user_node is not None, f"No user found at index {i}"
1187
1188                    args.append(self.serialize_tensor_output(sub_user_node.name, m))
1189                output_arguments.append(Argument.create(as_tensors=args))
1190            elif isinstance(meta, (int, SymInt)):
1191                user_node = _output_node_at_index(node, idx)
1192                name = (
1193                    user_node.name
1194                    if user_node is not None
1195                    else f"{node.name}_unused_{idx}"
1196                )
1197                output_arguments.append(self.serialize_output(name, meta))
1198            else:
1199                raise ValueError(
1200                    f"Unhandled output type {type(meta)} from node {node.format_node()}"
1201                )
1202
1203        return output_arguments
1204
1205    def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]:
1206        """
1207        For serializing HOO outputs since HOOs do not have a schema.
1208        """
1209        meta_val = node.meta["val"]
1210
1211        if isinstance(meta_val, tuple):
1212            # Note: Since we don't have a schema, we just serialize all tuple
1213            # outputs to be a list of values. Even if the output is supposed to
1214            # be a tensor list (Tensor[]), we will serialize it to be a list of
1215            # tensors (Tensor, Tensor, Tensor). An exception is that if there's
1216            # a singleton tensor, we will serialize this to be a singleton
1217            # tensor list so that the deserializer knows to insert getitem nodes.
1218
1219            if len(meta_val) == 1:
1220                assert isinstance(meta_val[0], torch.Tensor)
1221                user_node = _output_node_at_index(node, 0)
1222                name = (
1223                    user_node.name if user_node is not None else f"{node.name}_unused_0"
1224                )
1225                return [
1226                    Argument.create(
1227                        as_tensors=[self.serialize_tensor_output(name, meta_val[0])]
1228                    )
1229                ]
1230
1231            outputs = []
1232            for i, element_meta_val in enumerate(meta_val):
1233                user_node = _output_node_at_index(node, i)
1234                if isinstance(element_meta_val, list):
1235                    # e.g "-> Tensor[]"
1236                    assert user_node is not None
1237
1238                    tensors = []
1239                    for j, m in enumerate(element_meta_val):
1240                        if not isinstance(m, torch.Tensor):
1241                            raise SerializeError(
1242                                f"Serialize list output with type {type(m)} nyi"
1243                            )
1244
1245                        sub_user_node = _output_node_at_index(user_node, j)
1246                        name = (
1247                            sub_user_node.name
1248                            if sub_user_node is not None
1249                            else f"{user_node.name}_unused_{j}"
1250                        )
1251                        tensors.append(self.serialize_tensor_output(name, m))
1252                    outputs.append(Argument.create(as_tensors=tensors))
1253
1254                else:
1255                    name = (
1256                        user_node.name
1257                        if user_node is not None
1258                        else f"{node.name}_unused_{i}"
1259                    )
1260
1261                    outputs.append(self.serialize_output(name, element_meta_val))
1262
1263            return outputs
1264        else:
1265            return [self.serialize_output(node.name, meta_val)]
1266
1267    def serialize_output(self, name: str, meta_val: Any) -> Argument:
1268        # Check single value return
1269        if meta_val is None:
1270            return Argument.create(as_none=())
1271        if isinstance(meta_val, torch.Tensor):
1272            # e.g "-> Tensor"
1273            return Argument.create(
1274                as_tensor=self.serialize_tensor_output(name, meta_val)
1275            )
1276        elif isinstance(meta_val, (int, torch.SymInt)):
1277            # e.g "-> SymInt"
1278            return Argument.create(
1279                as_sym_int=self.serialize_sym_int_output(name, meta_val)
1280            )
1281        elif isinstance(meta_val, torch.SymBool):
1282            # e.g "-> SymBool"
1283            return Argument.create(
1284                as_sym_bool=self.serialize_sym_bool_output(name, meta_val)
1285            )
1286
1287        # list outputs should've been handled earlier
1288        raise SerializeError(f"Unable to serialize output {meta_val}")
1289
1290    def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]:
1291        meta_val = node.meta["val"]
1292
1293        idx_to_name = {}
1294        for user in node.users:
1295            assert (
1296                user.target is operator.getitem
1297            ), f"User node {user} of {node} is incorrect"
1298            idx_to_name[user.args[1]] = user.name
1299
1300        for idx, _ in enumerate(meta_val):
1301            # FX does not emit a getitem node for any outputs that are unused.
1302            # However, we need a name for them so that the number of outputs will
1303            # correctly match the schema. Just assign a dummy name.
1304            if idx not in idx_to_name:
1305                idx_to_name[idx] = f"{node.name}_unused_{idx}"
1306
1307        arg_list = []
1308        for i, element_meta_val in enumerate(meta_val):
1309            arg_list.append(
1310                self.serialize_tensor_output(idx_to_name[i], element_meta_val)
1311            )
1312
1313        return arg_list
1314
1315    def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph:
1316        assert isinstance(graph_module, torch.fx.GraphModule)
1317        for node in graph_module.graph.nodes:
1318            try:
1319                getattr(self, f"handle_{node.op}")(node)
1320            except Exception as e:
1321                raise SerializeError(
1322                    f"Failed serializing node {node} in graph: {node.format_node()}"
1323                ) from e
1324
1325        return Graph(
1326            inputs=self.graph_state.inputs,
1327            nodes=self.graph_state.nodes,
1328            tensor_values=self.graph_state.tensor_values,
1329            sym_int_values=self.graph_state.sym_int_values,
1330            sym_bool_values=self.graph_state.sym_bool_values,
1331            custom_obj_values=self.graph_state.custom_obj_values,
1332            outputs=self.graph_state.outputs,
1333            is_single_tensor_return=self.graph_state.is_single_tensor_return,
1334        )
1335
1336    def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule:
1337        graph = self.serialize_graph(graph_module)
1338
1339        return GraphModule(
1340            graph=graph,
1341            signature=self.serialize_signature(self.graph_signature),
1342            module_call_graph=self.serialize_module_call_graph(self.module_call_graph),
1343        )
1344
1345
class ExportedProgramSerializer:
    """Serializes an exported program into a ``_SerializedProgram``."""

    def __init__(self, opset_version: Optional[Dict[str, int]] = None):
        """Store the opset versions, filling in ``"aten"`` when it is missing.

        Args:
            opset_version: Optional opset-name to version mapping. When it has
                no ``"aten"`` entry, the value returned by
                ``torch._C._get_max_operator_version()`` is used for it.
        """
        self.opset_version: Dict[str, int] = (
            dict(opset_version) if opset_version else {}
        )
        if "aten" not in self.opset_version:
            self.opset_version["aten"] = torch._C._get_max_operator_version()

    def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram:
        """
        Validate and serialize ``exported_program``.

        Args:
            exported_program: Exported Program to serialize

        Returns:
            The serialized program together with the serialized state dict,
            constants and example inputs.
        """
        exported_program._validate()

        gm_serializer = GraphModuleSerializer(
            exported_program.graph_signature, exported_program.module_call_graph
        )
        serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
        serialized_range_constraints = serialize_range_constraints(
            exported_program.range_constraints
        )

        # TODO: Directly serialize exported_program.constants once
        # CustomClassHolders get stored in the ExportedProgram rather than in
        # the graph
        constants = dict(gm_serializer.custom_objs)
        for const_name, const_value in exported_program.constants.items():
            assert const_name not in constants
            constants[const_name] = const_value

        # Depending on which attribute the program exposes, forward either the
        # dialects of its verifiers or its single dialect.
        if hasattr(exported_program, "verifiers"):
            extra_fields = {
                "verifiers": [v.dialect for v in exported_program.verifiers]
            }
        elif hasattr(exported_program, "dialect"):
            extra_fields = {"dialect": exported_program.dialect}
        else:
            extra_fields = {}

        schema_version = SchemaVersion(
            major=SCHEMA_VERSION[0],
            minor=SCHEMA_VERSION[1],
        )
        serialized_ep = ExportedProgram(
            graph_module=serialized_graph_module,
            opset_version=self.opset_version,
            range_constraints=serialized_range_constraints,
            schema_version=schema_version,
            **extra_fields,
        )

        # Test canonical form is well defined.
        canonicalize(serialized_ep)

        state_dict_artifact = serialize_torch_artifact(exported_program.state_dict)
        constants_artifact = serialize_torch_artifact(constants)
        example_inputs_artifact = serialize_torch_artifact(
            exported_program.example_inputs
        )
        return _SerializedProgram(
            serialized_ep,
            state_dict_artifact,
            constants_artifact,
            example_inputs_artifact,
        )
1406
1407
1408class GraphModuleDeserializer:
    # Plain record grouping a graph module with its signature and associated
    # data. NOTE(review): presumably the return value of this class's
    # deserialize entry point, which is not visible in this part of the file.
    @dataclasses.dataclass
    class Result:
        # The FX graph module.
        graph_module: torch.fx.GraphModule
        # Export graph signature for the graph module.
        signature: ep.ExportGraphSignature
        # Module call entries (torch.export types, not the schema types).
        module_call_graph: List[ep.ModuleCallEntry]
        # Mapping from a name to its sympy symbol.
        names_to_symbols: Dict[str, sympy.Symbol]
        # Tensors/parameters keyed by name.
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
        # Constant tensors and script objects keyed by name.
        constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]]
        # Optional (positional args, keyword args) example inputs.
        example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]]
1418
1419    def __init__(self):
1420        self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
1421        self.serialized_name_to_meta: Dict[str, MetaType] = {}
1422        self.graph = torch.fx.Graph()
1423        self.module = torch.nn.Module()
1424
1425    @contextmanager
1426    def save_graph_module(self) -> Iterator[None]:
1427        saved = (
1428            self.graph,
1429            self.module,
1430            self.serialized_name_to_node,
1431            self.serialized_name_to_meta,
1432        )
1433        self.graph = torch.fx.Graph()
1434        self.module = torch.nn.Module()
1435        self.serialized_name_to_node = {}
1436        self.serialized_name_to_meta = {}
1437        try:
1438            yield
1439        finally:
1440            (
1441                self.graph,
1442                self.module,
1443                self.serialized_name_to_node,
1444                self.serialized_name_to_meta,
1445            ) = saved
1446
1447    def deserialize_operator(self, serialized_target: str):
1448        if serialized_target.startswith(
1449            "_operator"
1450        ):  # TODO(zhxchen17) Follow up on this.
1451            module = operator
1452            serialized_target_names = serialized_target.split(".")[1:]
1453        elif serialized_target.startswith("torch"):
1454            module = torch  # type: ignore[misc]
1455            serialized_target_names = serialized_target.split(".")[1:]
1456        else:  # TODO(zhxchen17) Don't catch all here.
1457            return serialized_target
1458
1459        target = module
1460        for name in serialized_target_names:
1461            if not hasattr(target, name):
1462                return serialized_target
1463            else:
1464                target = getattr(target, name)
1465        return target
1466
1467    def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
1468        val = s.value
1469        if s.type == "as_expr":
1470            if val.hint is None:
1471                hint = None
1472            else:
1473                assert val.hint.type == "as_int"
1474                hint = val.hint.value
1475
1476            if val.expr_str in self.symbol_name_to_symbol:
1477                sym = self.symbol_name_to_symbol[val.expr_str]
1478                if (
1479                    isinstance(sym, sympy.Symbol)
1480                    and sym not in self.shape_env.var_to_val
1481                ):
1482                    if hint is not None:
1483                        self.shape_env.add_var_to_val(sym, hint)
1484            else:
1485                sym = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)
1486                # NOTE(avik): Assumptions on symbols are not explicitly serialized.
1487                # This seems dangerous: it might cause unknown differences in shape env behavior
1488                # on deserialization? Probably deserves a follow-up.
1489
1490                # Here we force symbols corresponding to SymInts to be at least integers.
1491                # Otherwise some expressions that the shape env would otherwise evaluate to False,
1492                # e.g., 2*s = 9, can have rational solutions, e.g., 9/2.
1493                sym = sym.subs(
1494                    {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}
1495                )
1496                if isinstance(sym, sympy.Symbol):
1497                    self.symbol_name_to_symbol[val.expr_str] = sym
1498                    if hint is not None:
1499                        self.shape_env.add_var_to_val(sym, hint)
1500
1501                    if vr := self.symbol_name_to_range.get(val.expr_str):
1502                        self.shape_env.constrain_symbol_range(
1503                            sym,
1504                            compiler_min=vr.lower,  # type: ignore[arg-type]
1505                            compiler_max=vr.upper,  # type: ignore[arg-type]
1506                        )
1507                else:
1508                    # Placeholders, in particular, can have shapes as symbolic expressions.
1509                    # We need to populate the shape env with the range constraints of their
1510                    # free symbols, otherwise evaluating such expressions will error.
1511                    self.symbol_name_to_symbol[val.expr_str] = sym
1512                    free_symbols = sym.free_symbols
1513                    for s in free_symbols:
1514                        if s.name not in self.symbol_name_to_symbol:
1515                            self.symbol_name_to_symbol[s.name] = s
1516                        if vr := self.symbol_name_to_range.get(s.name):
1517                            self.shape_env.constrain_symbol_range(
1518                                s,
1519                                compiler_min=vr.lower,  # type: ignore[arg-type]
1520                                compiler_max=vr.upper,  # type: ignore[arg-type]
1521                            )
1522
1523            return self.shape_env.create_symintnode(sym, hint=hint)
1524        elif s.type == "as_int":
1525            assert isinstance(val, int)
1526            return val
1527        else:
1528            raise SerializeError(
1529                f"SymInt has invalid field type {s.type} with value {s.value}"
1530            )
1531
1532    def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]:
1533        val = s.value
1534        if s.type == "as_expr":
1535            expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)
1536            return self.shape_env.create_symboolnode(expr)
1537        elif s.type == "as_bool":
1538            assert isinstance(val, bool)
1539            return val
1540        else:
1541            raise SerializeError(
1542                f"SymBool has invalid field type {s.type} with value {s.value}"
1543            )
1544
1545    def deserialize_tensor_meta(
1546        self,
1547        tensor_meta: TensorMeta,
1548    ) -> FakeTensor:
1549        with self.fake_tensor_mode:
1550            return cast(
1551                FakeTensor,
1552                torch.empty_strided(
1553                    tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes),  # type: ignore[misc]
1554                    tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides),  # type: ignore[misc]
1555                    device=deserialize_device(tensor_meta.device),
1556                    dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype],
1557                ),
1558            )
1559
1560    def deserialize_script_obj_meta(
1561        self, script_obj_meta: CustomObjArgument
1562    ) -> ep.CustomObjArgument:
1563        return ep.CustomObjArgument(
1564            name=script_obj_meta.name,
1565            class_fqn=script_obj_meta.class_fqn,
1566        )
1567
1568    def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]:
1569        if output.type == "as_tensor":
1570            return self.serialized_name_to_node[output.as_tensor.name]
1571        elif output.type == "as_sym_int":
1572            return self.serialized_name_to_node[output.as_sym_int.as_name]
1573        elif output.type == "as_sym_bool":
1574            return self.serialized_name_to_node[output.as_sym_bool.as_name]
1575        elif output.type == "as_int":
1576            return output.as_int
1577        elif output.type == "as_none":
1578            return None
1579        else:
1580            raise SerializeError(f"Unable to deserialize output node {output}")
1581
    def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
        """Rebuild ``self.graph`` from a serialized ``Graph``.

        Runs in four phases: (1) deserialize the metadata of every recorded
        value into ``self.serialized_name_to_meta``, (2) create one placeholder
        per graph input, (3) deserialize each node, and (4) emit the single FX
        ``output`` node and fill in its ``meta["val"]``.

        Args:
            serialized_graph: The schema-level graph to convert.

        Returns:
            ``self.graph``, which is populated in place.

        Raises:
            SerializeError: on an unrecognized input type, or wrapping any
                exception raised while deserializing a node.
        """
        # Handle the tensor metas.
        for name, tensor_value in serialized_graph.tensor_values.items():
            meta_val = self.deserialize_tensor_meta(tensor_value)
            self.serialized_name_to_meta[name] = meta_val

        # Symbolic ints, symbolic bools and custom objects share the same
        # name -> meta table as tensors.
        for name, sym_int_value in serialized_graph.sym_int_values.items():
            self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value)

        for name, sym_bool_value in serialized_graph.sym_bool_values.items():
            self.serialized_name_to_meta[name] = self.deserialize_sym_bool(
                sym_bool_value
            )

        for name, script_obj_meta in serialized_graph.custom_obj_values.items():
            self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(
                script_obj_meta
            )

        # Inputs: convert to placeholder nodes in FX.
        for i, input_ in enumerate(serialized_graph.inputs):
            if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"):
                # Sym-int arguments keep their name in `as_name`; tensor and
                # custom-object arguments keep it in `name`.
                if input_.type == "as_sym_int":
                    node_name = input_.value.as_name
                else:
                    node_name = input_.value.name
                placeholder_node = self.graph.placeholder(node_name)
                # FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments)
                # we will overwrite it
                placeholder_node.name = node_name
                # Registers the node by name and attaches its meta["val"].
                self.sync_fx_node(node_name, placeholder_node)
            elif input_.type in (
                "as_int",
                "as_float",
                "as_bool",
                "as_none",
                "as_string",
            ):
                # Constant inputs: the placeholder name is taken from the i-th
                # input spec of the already-deserialized signature, and the
                # constant itself becomes the node's meta["val"].
                node_name = self.signature.input_specs[i].arg.name
                placeholder_node = self.graph.placeholder(node_name)
                placeholder_node.meta["val"] = self.deserialize_input(input_)
            else:
                raise SerializeError(f"Invalid input type {input_}")

        # Nodes: convert to call_function nodes.
        for serialized_node in serialized_graph.nodes:
            try:
                target = self.deserialize_operator(serialized_node.target)
                self.deserialize_node(serialized_node, target)

            except Exception as e:
                # Re-raise with the offending node attached; the original
                # exception is kept as the cause.
                raise SerializeError(
                    f"Failed deserializing node {serialized_node}"
                ) from e

        # Outputs: convert to a single `output` node.
        outputs = []
        for output in serialized_graph.outputs:
            outputs.append(self.deserialize_graph_output(output))

        # A single-tensor return is passed to FX bare; anything else is passed
        # as a tuple.
        if serialized_graph.is_single_tensor_return:
            assert len(outputs) == 1
            outputs = outputs[0]  # type: ignore[assignment]
        else:
            outputs = tuple(outputs)  # type: ignore[assignment]

        output_node = self.graph.output(outputs)

        # Mirror the returned values' metadata onto the output node; entries
        # that are not FX nodes (ints, None) are used as their own value.
        if serialized_graph.is_single_tensor_return:
            output_node.meta["val"] = output_node.args[0].meta["val"]
        else:
            output_node.meta["val"] = tuple(
                arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
                for arg in output_node.args[0]
            )

        return self.graph
1659
1660    def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
1661        if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
1662            name = serialized_node.outputs[0].value.as_name
1663            args = self.deserialize_sym_op_inputs(serialized_node.inputs)
1664
1665            fx_node = self.graph.create_node("call_function", target, args, {}, name)
1666            self.deserialize_sym_op_outputs(serialized_node, fx_node)
1667
1668        elif isinstance(target, torch._ops.HigherOrderOperator):
1669            args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs)
1670            # If HOP returns a single tensor, name the
1671            # newly-created node after it. This ensures that these tensor values
1672            # have names that are consistent with serialized.
1673            #
1674            # HOPs don't have schema yet, just check the output lengths and as_tensor attribute
1675            name = (
1676                serialized_node.outputs[0].as_tensor.name
1677                if len(serialized_node.outputs) == 1
1678                and hasattr(serialized_node.outputs[0], "as_tensor")
1679                else None
1680            )
1681            fx_node = self.graph.create_node(
1682                "call_function", target, args, kwargs, name
1683            )
1684            self.deserialize_outputs(serialized_node, fx_node)
1685            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
1686
1687        elif isinstance(target, torch._ops.OpOverload):
1688            # For convenience: if this node returns a single tensor, name the
1689            # newly-created node after it. This ensures that these tensor values
1690            # have names that are consistent with serialized.
1691            name = (
1692                serialized_node.outputs[0].as_tensor.name
1693                if _is_single_tensor_return(target)
1694                else None  # FX will generate a name for us.
1695            )
1696            args, kwargs = self.deserialize_inputs(target, serialized_node)
1697            fx_node = self.graph.create_node(
1698                "call_function", target, args, kwargs, name
1699            )
1700            self.deserialize_outputs(serialized_node, fx_node)
1701        else:
1702            raise SerializeError(
1703                f"Unsupported target type for node {serialized_node}: {target}"
1704            )
1705
1706        fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
1707        if (
1708            fx_node.op not in ["placeholder", "output"]
1709            and "nn_module_stack" not in fx_node.meta
1710        ):
1711            fx_node.meta["nn_module_stack"] = (
1712                {}
1713            )  # serialization throws away empty dicts
1714
1715    def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec:
1716        if i.type == "user_input":
1717            return ep.InputSpec(
1718                kind=ep.InputKind.USER_INPUT,
1719                arg=self.deserialize_argument_spec(i.user_input.arg),
1720                target=None,
1721            )
1722        elif i.type == "parameter":
1723            return ep.InputSpec(
1724                kind=ep.InputKind.PARAMETER,
1725                arg=ep.TensorArgument(name=i.parameter.arg.name),
1726                target=i.parameter.parameter_name,
1727            )
1728        elif i.type == "buffer":
1729            return ep.InputSpec(
1730                kind=ep.InputKind.BUFFER,
1731                arg=ep.TensorArgument(name=i.buffer.arg.name),
1732                target=i.buffer.buffer_name,
1733                persistent=i.buffer.persistent,
1734            )
1735        elif i.type == "tensor_constant":
1736            return ep.InputSpec(
1737                kind=ep.InputKind.CONSTANT_TENSOR,
1738                arg=ep.TensorArgument(name=i.tensor_constant.arg.name),
1739                target=i.tensor_constant.tensor_constant_name,
1740            )
1741        elif i.type == "custom_obj":
1742            return ep.InputSpec(
1743                kind=ep.InputKind.CUSTOM_OBJ,
1744                arg=ep.CustomObjArgument(
1745                    name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn
1746                ),
1747                target=i.custom_obj.custom_obj_name,
1748            )
1749        elif i.type == "token":
1750            return ep.InputSpec(
1751                kind=ep.InputKind.TOKEN,
1752                arg=ep.TokenArgument(name=i.token.arg.name),
1753                target=None,
1754            )
1755        elif i.type == "constant_input":
1756            return ep.InputSpec(
1757                kind=ep.InputKind.USER_INPUT,
1758                arg=ep.ConstantArgument(
1759                    name=i.constant_input.name,
1760                    value=self.deserialize_constant_input(i.constant_input.value),
1761                ),
1762                target=None,
1763            )
1764        else:
1765            raise AssertionError(f"Unknown input spec {i}")
1766
1767    def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
1768        if o.type == "user_output":
1769            return ep.OutputSpec(
1770                kind=ep.OutputKind.USER_OUTPUT,
1771                arg=self.deserialize_argument_spec(o.user_output.arg),
1772                target=None,
1773            )
1774        elif o.type == "loss_output":
1775            return ep.OutputSpec(
1776                kind=ep.OutputKind.LOSS_OUTPUT,
1777                arg=ep.TensorArgument(name=o.loss_output.arg.name),
1778                target=None,
1779            )
1780        elif o.type == "buffer_mutation":
1781            return ep.OutputSpec(
1782                kind=ep.OutputKind.BUFFER_MUTATION,
1783                arg=ep.TensorArgument(name=o.buffer_mutation.arg.name),
1784                target=o.buffer_mutation.buffer_name,
1785            )
1786        elif o.type == "gradient_to_parameter":
1787            return ep.OutputSpec(
1788                kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
1789                arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name),
1790                target=o.gradient_to_parameter.parameter_name,
1791            )
1792        elif o.type == "gradient_to_user_input":
1793            return ep.OutputSpec(
1794                kind=ep.OutputKind.GRADIENT_TO_USER_INPUT,
1795                arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name),
1796                target=o.gradient_to_user_input.user_input_name,
1797            )
1798        elif o.type == "user_input_mutation":
1799            return ep.OutputSpec(
1800                kind=ep.OutputKind.USER_INPUT_MUTATION,
1801                arg=ep.TensorArgument(name=o.user_input_mutation.arg.name),
1802                target=o.user_input_mutation.user_input_name,
1803            )
1804        elif o.type == "token":
1805            return ep.OutputSpec(
1806                kind=ep.OutputKind.TOKEN,
1807                arg=ep.TokenArgument(name=o.token.arg.name),
1808                target=None,
1809            )
1810        else:
1811            raise AssertionError(f"Unknown output spec {o}")
1812
1813    def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature:
1814        return ep.ExportGraphSignature(
1815            input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs],
1816            output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs],
1817        )
1818
    def deserialize(
        self,
        serialized_graph_module: GraphModule,
        serialized_state_dict: Union[Dict[str, torch.Tensor], bytes],
        constants: Union[Dict[str, Any], bytes],
        example_inputs: Optional[
            Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]
        ] = None,
        symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
    ) -> Result:
        """Deserialize a whole graph module together with its artifacts.

        Sets up a fresh shape environment and fake tensor mode on ``self``,
        deserializes the signature, graph and module call graph, and bundles
        the results.

        Args:
            serialized_graph_module: Schema-level graph module to convert.
            serialized_state_dict: State dict artifact, passed to
                ``deserialize_torch_artifact``.
            constants: Constants artifact, passed to
                ``deserialize_torch_artifact``.
            example_inputs: Optional example-inputs artifact; ignored when
                ``None`` or empty.
            symbol_name_to_range: Optional value ranges keyed by symbol name;
                lower bounds are adjusted as described below before use.

        Returns:
            A ``GraphModuleDeserializer.Result`` holding the graph module,
            signature, module call graph, symbol table, state dict, constants
            and example inputs.
        """
        global _CURRENT_DESERIALIZER
        # Snapshot the deserializer stack so the `finally` block can check that
        # this call leaves it exactly as it found it.
        current_deserializer_state = _CURRENT_DESERIALIZER.copy()
        _CURRENT_DESERIALIZER.append(self)
        try:
            self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
            self.fake_tensor_mode = FakeTensorMode(
                allow_fallback_kernels=False,
                allow_non_fake_inputs=True,
                shape_env=self.shape_env,
            )
            self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
            self.constants = deserialize_torch_artifact(constants)
            # The signature must exist before the graph is deserialized:
            # deserialize_graph reads it to name constant placeholders.
            self.signature = self.deserialize_signature(
                serialized_graph_module.signature
            )

            # deserialization does analysis with checks on 0/1, so we create fake range constraints and
            # restore the original range constraints afterwards
            self.symbol_name_to_range = {}
            if symbol_name_to_range:
                for k, vr in symbol_name_to_range.items():
                    # Keep infinite lower bounds as floats with their sign;
                    # finite ones are truncated to int.
                    if math.isinf(vr.lower) and vr.lower < 0:
                        lower = -math.inf
                    elif math.isinf(vr.lower):
                        lower = math.inf
                    else:
                        lower = int(vr.lower)

                    if vr.upper >= 2:  # max is >= 2, not sym bool range
                        # Raise the lower bound to at least 2 for these ranges.
                        lower = max(2, lower)
                    self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(
                        _int_to_sympy_int(lower), vr.upper
                    )

            if example_inputs is not None and len(example_inputs) > 0:
                self.example_inputs = deserialize_torch_artifact(example_inputs)
            else:
                self.example_inputs = None
            # Populates self.graph and self.symbol_name_to_symbol as a side effect.
            self.deserialize_graph(serialized_graph_module.graph)

            module_call_graph = self.deserialize_module_call_graph(
                serialized_graph_module.module_call_graph
            )
            return GraphModuleDeserializer.Result(
                graph_module=ep._create_graph_module_for_export(
                    self.module, self.graph
                ),
                signature=self.signature,
                module_call_graph=module_call_graph,
                names_to_symbols=self.symbol_name_to_symbol,
                state_dict=deserialize_torch_artifact(serialized_state_dict),
                constants=self.constants,
                example_inputs=self.example_inputs,
            )
        finally:
            # Always unregister this deserializer, even on failure, and verify
            # the stack matches the snapshot taken on entry.
            _CURRENT_DESERIALIZER.pop()
            assert current_deserializer_state == _CURRENT_DESERIALIZER
1886
1887    def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
1888        if name in self.serialized_name_to_node:
1889            raise SerializeError(f"Node {name} has already been deserialized before.")
1890        self.serialized_name_to_node[name] = fx_node
1891        assert "val" not in fx_node.meta
1892        fx_node.meta["val"] = self.serialized_name_to_meta[name]
1893
1894    def deserialize_sym_op_inputs(self, inputs):
1895        return tuple(self.deserialize_input(input.arg) for input in inputs)
1896
1897    def deserialize_inputs(self, target: torch._ops.OpOverload, serialized_node: Node):
1898        schema_args = target._schema.arguments
1899        actual_args = {
1900            input.name: self.deserialize_input(input.arg)
1901            for input in serialized_node.inputs
1902        }
1903        args = []
1904        kwargs = {}
1905        for schema_arg in schema_args:
1906            is_positional = (
1907                not schema_arg.has_default_value() and not schema_arg.kwarg_only
1908            )
1909            if is_positional:
1910                args.append(actual_args[schema_arg.name])
1911            else:
1912                if schema_arg.name in actual_args:
1913                    kwargs[schema_arg.name] = actual_args[schema_arg.name]
1914        return tuple(args), kwargs
1915
1916    def deserialize_hoo_inputs(self, inputs: List[NamedArgument]):
1917        """
1918        For deserializing HOO inputs since HOOs do not have a schema.
1919        """
1920        args = []
1921        kwargs = {}
1922        for input_ in inputs:
1923            if input_.name != "":
1924                kwargs[input_.name] = self.deserialize_input(input_.arg)
1925            else:
1926                args.append(self.deserialize_input(input_.arg))
1927        return (tuple(args), kwargs)
1928
1929    def deserialize_input(self, inp: Argument) -> Any:
1930        value = inp.value
1931        typ_ = inp.type
1932        if typ_ == "as_none":
1933            # None should converted as None, but is encoded as bool in serialized
1934            # Convert serialized object to torch equivalent
1935            return None
1936        elif typ_ == "as_tensor":
1937            return self.serialized_name_to_node[inp.as_tensor.name]
1938        elif typ_ == "as_scalar_type":
1939            return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type]
1940        elif typ_ == "as_memory_format":
1941            return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format]
1942        elif typ_ == "as_layout":
1943            return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout]
1944        elif typ_ == "as_graph":
1945            assert isinstance(value, GraphArgument)
1946            with self.save_graph_module():
1947                self.deserialize_graph(value.graph)
1948                submodule = ep._create_graph_module_for_export(self.module, self.graph)
1949            self.module.register_module(value.name, submodule)
1950            return self.graph.create_node(
1951                "get_attr",
1952                value.name,
1953                name=value.name,
1954            )
1955        elif typ_ == "as_device":
1956            return deserialize_device(inp.as_device)
1957        elif typ_ == "as_int":
1958            return inp.as_int
1959        elif typ_ == "as_float":
1960            return inp.as_float
1961        elif typ_ == "as_bool":
1962            return inp.as_bool
1963        elif typ_ == "as_string":
1964            return inp.as_string
1965        elif typ_ == "as_sym_int":
1966            return self.deserialize_sym_argument(inp.as_sym_int)
1967        elif typ_ == "as_sym_bool":
1968            return self.deserialize_sym_argument(inp.as_sym_bool)
1969        elif isinstance(value, list):
1970            if len(value) == 0:
1971                return []
1972            elif typ_ == "as_tensors":
1973                result = []
1974                for arg in value:
1975                    result.append(self.serialized_name_to_node[arg.name])
1976                return result
1977            elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
1978                # convert from serialized.python.types.List to python list
1979                return list(value)
1980            elif typ_ in ("as_sym_ints", "as_sym_bools"):
1981                return [self.deserialize_sym_argument(arg) for arg in value]
1982            elif typ_ == "as_optional_tensors":
1983
1984                def deserialize_optional_tensor_args(a):
1985                    if a.type == "as_none":
1986                        return None
1987                    elif a.type == "as_tensor":
1988                        return self.serialized_name_to_node[a.value.name]
1989                    else:
1990                        raise SerializeError(f"Unhandled argument {inp}")
1991
1992                return list(map(deserialize_optional_tensor_args, value))
1993            else:
1994                raise SerializeError(f"Unhandled argument {inp}")
1995        elif typ_ == "as_custom_obj":
1996            if inp.as_custom_obj.name in self.serialized_name_to_node:
1997                # Custom object has been lifted as an input
1998                return self.serialized_name_to_node[inp.as_custom_obj.name]
1999            return self.constants[inp.as_custom_obj.name]
2000        elif typ_ == "as_operator":
2001            return self.deserialize_operator(inp.as_operator)
2002        else:
2003            raise SerializeError(f"Unhandled argument {inp}")
2004
2005    def deserialize_constant_input(self, inp: ConstantValue) -> Any:
2006        if inp.type == "as_int":
2007            return int(inp.as_int)
2008        elif inp.type == "as_float":
2009            return float(inp.as_float)
2010        elif inp.type == "as_string":
2011            return str(inp.as_string)
2012        elif inp.type == "as_bool":
2013            return bool(inp.as_bool)
2014        elif inp.type == "as_none":
2015            return None
2016        else:
2017            raise SerializeError(f"Unhandled constant argument {inp} to deserialize")
2018
2019    def deserialize_sym_argument(self, sym_arg):
2020        if isinstance(sym_arg, SymIntArgument):
2021            if sym_arg.type == "as_int":
2022                return sym_arg.as_int
2023            elif sym_arg.type == "as_name":
2024                return self.serialized_name_to_node[sym_arg.as_name]
2025        elif isinstance(sym_arg, SymBoolArgument):
2026            if sym_arg.type == "as_bool":
2027                return sym_arg.as_bool
2028            elif sym_arg.type == "as_name":
2029                return self.serialized_name_to_node[sym_arg.as_name]
2030        raise SerializeError(f"Unknown symbolic argument type: {sym_arg}")
2031
2032    def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
2033        self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
2034
2035    def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
2036        # Check single value return
2037        if len(serialized_node.outputs) == 0:
2038            return
2039        if (
2040            len(serialized_node.outputs) == 1
2041            and serialized_node.outputs[0].type == "as_tensor"
2042        ):
2043            self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
2044            return
2045        elif len(serialized_node.outputs) == 1 and isinstance(
2046            serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)
2047        ):
2048            self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
2049            return
2050
2051        self.deserialize_multiple_outputs(serialized_node, fx_node)
2052
    def deserialize_multiple_outputs(
        self, serialized_node: Node, fx_node: torch.fx.Node
    ) -> None:
        """Expand a multi-output node into ``operator.getitem`` nodes.

        One ``getitem`` node is created per returned value (recursing into
        nested lists/tuples), each named value is registered via
        ``sync_fx_node``, and ``fx_node.meta["val"]`` is set to the tuple of the
        collected values.

        Args:
            serialized_node: The serialized node whose outputs are expanded.
            fx_node: The FX node that produces those outputs.

        Raises:
            NotImplementedError: if an output is not a tensor argument, a
                sym-int argument, or a list/tuple of those.
        """
        # Deserialized once and shared by every derived getitem node.
        deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)

        def generate_getitem(
            meta_val,
            fx_node: torch.fx.Node,
            arg: Union[TensorArgument, SymIntArgument],
            idx: int,
        ):
            # Emit `fx_node[idx]` as a node named after the serialized output
            # and append that output's meta value to `meta_val`.
            if isinstance(arg, TensorArgument):
                name = arg.name
            elif isinstance(arg, SymIntArgument):
                name = arg.as_name
            else:
                raise AssertionError(
                    f"generate_getitem got unknown argument type {type(arg)}"
                )
            individual_output = self.graph.create_node(
                "call_function",
                operator.getitem,
                (fx_node, idx),
                name=name,
            )
            self.sync_fx_node(name, individual_output)
            meta_val.append(self.serialized_name_to_meta[name])
            # The derived `getitem` nodes should have the same stacktrace as the
            # original `fx_node`
            individual_output.meta.update(deserialized_metadata)

        def generate_getitems(meta_val, fx_node: torch.fx.Node, args):
            # Walk `args` in order; nested lists/tuples get an intermediate
            # unnamed getitem node and a nested list in `meta_val`.
            for idx, arg in enumerate(args):
                if isinstance(arg, Argument):
                    # Unwrap the union to its concrete payload.
                    arg = arg.value
                if isinstance(arg, (TensorArgument, SymIntArgument)):
                    generate_getitem(meta_val, fx_node, arg, idx)
                elif isinstance(arg, (list, tuple)):
                    list_output = self.graph.create_node(
                        "call_function",
                        operator.getitem,
                        (fx_node, idx),
                    )
                    meta_val.append([])
                    generate_getitems(meta_val[-1], list_output, arg)
                    list_output.meta.update(deserialized_metadata)
                    list_output.meta["val"] = meta_val[-1]
                else:
                    raise NotImplementedError(f"Unimplemented node output type: {arg}")

        # Convert multiple return types to FX format.
        # In FX, each node only returns one value. So in order to represent
        # multiple return values, we have to emit a `getitem` node for each
        # return value.
        # This performs the inverse mapping of the `serialize_outputs` call in
        # serialization, see [NOTE: Multiple outputs]
        meta_val: List[Any] = []
        if len(serialized_node.outputs) == 1:
            # A single serialized output here is expected to be a list of
            # tensors; its elements are the individual return values.
            assert isinstance(serialized_node.outputs[0].value, list)
            assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
            generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors)
        else:
            generate_getitems(meta_val, fx_node, serialized_node.outputs)

        # also update the metaval for `fx_node` to be a list(meta)
        fx_node.meta["val"] = tuple(meta_val)
        # Register the producing node under its own FX name.
        self.serialized_name_to_node[fx_node.name] = fx_node
2120
2121    def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
2122        ret: Dict[str, Any] = {}
2123        if stack_trace := metadata.get("stack_trace"):
2124            ret["stack_trace"] = stack_trace
2125
2126        def deserialize_meta_func(serialized_target: str):
2127            module = None
2128            if serialized_target.startswith("torch.nn"):
2129                module = torch.nn
2130                serialized_target_names = serialized_target.split(".")[2:]
2131            elif serialized_target.startswith("torch"):
2132                module = torch
2133                serialized_target_names = serialized_target.split(".")[1:]
2134            else:
2135                return self.deserialize_operator(serialized_target)
2136
2137            target = module
2138            for name in serialized_target_names:
2139                if not hasattr(target, name):
2140                    return serialized_target
2141                else:
2142                    target = getattr(target, name)
2143            return target
2144
2145        if nn_module_stack_str := metadata.get("nn_module_stack"):
2146            # Originally serialized to "key,orig_path,type_str"
2147            def import_nn_module_stack(key, path, ty):
2148                return key, (path, ty)
2149
2150            # Helper function that splits strings by commas except for those
2151            # encapsulated by parens, which are valid traces.
2152            # TODO: Currently this is needed due to indexing Sequential
2153            # layers introducing names in the form "layer.slice(1, None, None)".
2154            # If that naming is improved, this fancier splitting can probably be
2155            # reverted to a simple split by comma.
2156            def metadata_split(metadata):
2157                # Remove the parentheses and commas inside them
2158                metadata = re.sub(r"\(.*?\)", "", metadata)
2159                # Split the string by comma, except for those inside parentheses
2160                return re.split(r"(?<!\()\s*,\s*(?!\()", metadata)
2161
2162            nn_module_stack = dict(
2163                import_nn_module_stack(*metadata_split(item))
2164                for item in nn_module_stack_str.split(ST_DELIMITER)
2165            )
2166            ret["nn_module_stack"] = nn_module_stack
2167
2168        if source_fn_st_str := metadata.get("source_fn_stack"):
2169            # Originally serializes to "fx_node_name,op_str"
2170            source_fn_st = []
2171            for source_fn_str in source_fn_st_str.split(ST_DELIMITER):
2172                name, target_str = source_fn_str.split(",")
2173                source_fn_st.append((name, deserialize_meta_func(target_str)))
2174            ret["source_fn_stack"] = source_fn_st
2175
2176        if torch_fn_str := metadata.get("torch_fn"):
2177            ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER))
2178        return ret
2179
2180    def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
2181        if x.type == "as_tensor":
2182            return ep.TensorArgument(name=x.as_tensor.name)
2183        elif x.type == "as_sym_int":
2184            return ep.SymIntArgument(name=x.as_sym_int.as_name)
2185        elif x.type == "as_custom_obj":
2186            return ep.ConstantArgument(
2187                name=x.as_custom_obj.name, value=self.deserialize_input(x)
2188            )
2189        else:
2190            return ep.ConstantArgument(name="", value=self.deserialize_input(x))
2191
2192    def deserialize_module_call_signature(
2193        self, module_call_signature: ModuleCallSignature
2194    ) -> ep.ModuleCallSignature:
2195        return ep.ModuleCallSignature(
2196            inputs=[
2197                self.deserialize_argument_spec(x) for x in module_call_signature.inputs
2198            ],
2199            outputs=[
2200                self.deserialize_argument_spec(x) for x in module_call_signature.outputs
2201            ],
2202            in_spec=treespec_loads(module_call_signature.in_spec),
2203            out_spec=treespec_loads(module_call_signature.out_spec),
2204        )
2205
2206    def deserialize_module_call_graph(
2207        self, module_call_graph: List[ModuleCallEntry]
2208    ) -> List[ep.ModuleCallEntry]:
2209        return [
2210            ep.ModuleCallEntry(
2211                fqn=entry.fqn,
2212                signature=(
2213                    self.deserialize_module_call_signature(entry.signature)
2214                    if entry.signature
2215                    else None
2216                ),
2217            )
2218            for entry in module_call_graph
2219        ]
2220
2221
class ExportedProgramDeserializer:
    """Rebuilds an ``ep.ExportedProgram`` from its serialized schema form.

    Attributes:
        expected_opset_version: Opset version the caller expects for each op
            namespace. The ``"aten"`` entry is always present; when the
            caller does not supply it, it is taken from
            ``torch._C._get_max_operator_version()``.
    """

    def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None):
        self.expected_opset_version: Dict[str, int] = {}
        if expected_opset_version:
            self.expected_opset_version.update(expected_opset_version)
        if "aten" not in self.expected_opset_version:
            self.expected_opset_version["aten"] = torch._C._get_max_operator_version()

    def deserialize_range_constraints(
        self,
        symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges],
        symbol_name_to_symbol: Dict[str, sympy.Symbol],
    ) -> Dict[sympy.Symbol, ValueRanges]:
        """Re-key range constraints from symbol names to symbol objects.

        Args:
            symbol_name_to_range: Value range recorded for each symbol name.
            symbol_name_to_symbol: Symbols recovered while deserializing the
                graph, keyed by name.

        Returns:
            A dict mapping each symbol found in ``symbol_name_to_symbol`` to
            its range. Names with no matching symbol are dropped after a
            warning is logged.
        """
        range_constraints = {}
        for k, v in symbol_name_to_range.items():
            if symbol := symbol_name_to_symbol.get(k):
                range_constraints[symbol] = v  # type: ignore[arg-type]
            else:
                log.warning(f"Symbol {k} did not appear in the graph that was deserialized")  # noqa: G004
        return range_constraints

    def deserialize(
        self,
        exported_program: ExportedProgram,
        state_dict: Union[Dict[str, torch.Tensor], bytes],
        constants: Union[Dict[str, torch.Tensor], bytes],
        example_inputs: Optional[
            Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]
        ] = None,
    ) -> ep.ExportedProgram:
        """Deserialize a schema ``ExportedProgram`` into ``ep.ExportedProgram``.

        Args:
            exported_program: The serialized (schema) program.
            state_dict: State dict, either materialized or as raw bytes;
                forwarded unchanged to ``GraphModuleDeserializer``.
            constants: Constants, either materialized or as raw bytes;
                forwarded unchanged to ``GraphModuleDeserializer``.
            example_inputs: Optional example inputs, forwarded unchanged.

        Returns:
            The result of running ``GraphModuleOpUpgrader.upgrade`` on the
            reconstructed program.

        Raises:
            SerializeError: If the schema major version differs from
                ``SCHEMA_VERSION[0]`` and the serialized version is not 0.0.
            RuntimeError: If the serialized program has no opset version.
            NotImplementedError: If an opset version shared with
                ``expected_opset_version`` does not match exactly.
        """
        assert isinstance(exported_program, ExportedProgram)
        version = exported_program.schema_version

        # TODO(zhxchen17) blocked on thrift schema refactor
        if version.major != SCHEMA_VERSION[0] and not (
            version.major == 0 and version.minor == 0
        ):
            raise SerializeError(
                f"Serialized schema version {exported_program.schema_version} "
                f"does not match our current schema version {SCHEMA_VERSION}."
            )

        symbol_name_to_range = {
            k: symbolic_shapes.ValueRanges(
                _int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)
            )
            for k, v in exported_program.range_constraints.items()
        }
        res = GraphModuleDeserializer().deserialize(
            exported_program.graph_module,
            state_dict,
            constants,
            example_inputs,
            symbol_name_to_range,
        )
        range_constraints = self.deserialize_range_constraints(
            symbol_name_to_range,
            res.names_to_symbols,
        )
        model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
        self._validate_model_opset_version(model_opset_version)

        upgrader = GraphModuleOpUpgrader(
            self.expected_opset_version, model_opset_version
        )

        # Use a distinct name so the serialized input is not shadowed.
        deserialized_program = ep.ExportedProgram(
            root=res.graph_module,
            graph=res.graph_module.graph,
            graph_signature=res.signature,
            state_dict=res.state_dict,  # type: ignore[arg-type]
            range_constraints=range_constraints,
            module_call_graph=res.module_call_graph,
            example_inputs=res.example_inputs,
            verifier=load_verifier(exported_program.dialect),
            constants=res.constants,
        )
        return upgrader.upgrade(deserialized_program)

    def _validate_model_opset_version(
        self, model_opset_version: Optional[Dict[str, int]]
    ):
        """Check ``model_opset_version`` against ``expected_opset_version``.

        E.g., model_opset_version = {"aten": 3, "custom": 4}
        expected_opset_version = {"aten": 4, "custom": 4}

        For op namespaces present in both dicts:
        1. model version < expected version would need an upgrader.
        2. model version > expected version would need a downgrader.
        3. model version == expected version needs no extra handling.
        Neither upgraders nor downgraders are implemented yet, so cases 1 and
        2 both raise ``NotImplementedError``.

        For an op namespace that appears only in ``model_opset_version`` a
        warning is logged, because it is missing from
        ``expected_opset_version``.

        Raises:
            RuntimeError: If ``model_opset_version`` is empty or ``None``.
            NotImplementedError: If a shared namespace has differing versions.
        """
        if not model_opset_version:
            raise RuntimeError("Serialized model should have opset version.")
        common_namespaces = {
            key for key in model_opset_version if key in self.expected_opset_version
        }
        for namespace in common_namespaces:
            model_version = model_opset_version[namespace]
            assert isinstance(
                model_version, int
            ), f"model_opset_version value should be int, got {model_version}"

            compiler_version = self.expected_opset_version[namespace]
            assert isinstance(
                compiler_version, int
            ), f"expected_opset_version value should be int, got {compiler_version}"

            # TODO(larryliu0820): Add support for upgrader & downgrader
            if model_version != compiler_version:
                raise NotImplementedError(
                    f"Model opset version {model_opset_version} doesn't match to compiler opset version "
                    f"{self.expected_opset_version}! Upgrader/downgrader is not implemented yet."
                )
        for namespace in model_opset_version:
            if namespace in common_namespaces:
                continue
            # `extra` only attaches attributes to the log record; it is never
            # interpolated into the message, so the namespace is passed as a
            # lazy %-style argument to make it show up in the emitted text.
            log.warning(
                "Compiler doesn't have a version table for op namespace: %s. ",
                namespace,
                extra={"ns": namespace},
            )
2349
2350
class EnumEncoder(json.JSONEncoder):
    """JSON encoder that also accepts ``Enum`` members and raw ``bytes``."""

    def default(self, obj):
        """Return a JSON-friendly stand-in for ``obj``.

        Enum members collapse to their ``value`` and ``bytes`` become a
        base64 text string. Anything else is deferred to the base encoder.
        """
        if isinstance(obj, Enum):
            return obj.value
        if not isinstance(obj, bytes):
            return super().default(obj)
        encoded = base64.b64encode(obj)
        return encoded.decode("utf-8")
2358
2359
def _dataclass_to_dict(obj):
    """Recursively lower schema objects to plain dicts, lists and tuples.

    A ``_Union`` becomes a single-entry dict keyed by its active type. A
    dataclass becomes a dict of its fields, omitting any field whose default
    is ``None`` and whose current value is also ``None``. Lists, tuples and
    dicts are converted element-wise; every other value is returned as is.
    """
    if isinstance(obj, _Union):
        return {obj.type: _dataclass_to_dict(obj.value)}

    if dataclasses.is_dataclass(obj):
        lowered = {}
        for fld in dataclasses.fields(obj):
            value = getattr(obj, fld.name)
            # Skip optional fields that were left unset.
            if fld.default is None and value is None:
                continue
            lowered[fld.name] = _dataclass_to_dict(value)
        return lowered

    if isinstance(obj, list):
        return [_dataclass_to_dict(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_dataclass_to_dict(item) for item in obj)
    if isinstance(obj, dict):
        return {key: _dataclass_to_dict(val) for key, val in obj.items()}
    return obj
2377
2378
def serialize(
    exported_program: ep.ExportedProgram,
    opset_version: Optional[Dict[str, int]] = None,
) -> SerializedArtifact:
    """Serialize an ``ep.ExportedProgram`` into a ``SerializedArtifact``.

    The program itself is lowered to a plain dict, JSON-encoded with
    ``EnumEncoder`` and stored as UTF-8 bytes; the state dict, constants and
    example inputs are passed through from ``ExportedProgramSerializer``.
    """
    serializer = ExportedProgramSerializer(opset_version)
    serialized_program = serializer.serialize(exported_program)
    assert isinstance(serialized_program.exported_program, ExportedProgram)

    program_as_dict = _dataclass_to_dict(serialized_program.exported_program)
    json_bytes = json.dumps(program_as_dict, cls=EnumEncoder).encode("utf-8")
    return SerializedArtifact(
        json_bytes,
        serialized_program.state_dict,
        serialized_program.constants,
        serialized_program.example_inputs,
    )
2399
2400
2401def _dict_to_dataclass(cls, data):
2402    assert not isinstance(cls, str), f"Unresolved class type: '{cls}'."
2403    if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls):
2404        if data is None:
2405            return None
2406        ty_args = typing.get_args(cls)
2407        assert len(ty_args) == 2
2408        return _dict_to_dataclass(ty_args[0], data)
2409    elif isinstance(cls, type) and issubclass(cls, _Union):
2410        assert isinstance(data, dict)
2411        assert len(data) == 1
2412        _type = next(iter(data.keys()))
2413        _value = next(iter(data.values()))
2414        assert isinstance(_type, str)
2415        field_type = cls.__annotations__[_type]
2416        return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
2417    elif dataclasses.is_dataclass(cls):
2418        obj = cls(**data)  # type: ignore[assignment]
2419        type_hints = typing.get_type_hints(cls)
2420        for f in dataclasses.fields(cls):
2421            name = f.name
2422            new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name))
2423            setattr(obj, name, new_field_obj)
2424        return obj
2425    elif isinstance(data, list):
2426        if len(data) == 0:
2427            return data
2428        d_type = typing.get_args(cls)[0]
2429        return [_dict_to_dataclass(d_type, d) for d in data]
2430    elif isinstance(data, dict):
2431        v_type = typing.get_args(cls)[1]
2432        return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()}
2433    return data
2434
2435
def deserialize(
    artifact: SerializedArtifact,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ep.ExportedProgram:
    """Reconstruct an ``ep.ExportedProgram`` from a ``SerializedArtifact``.

    The artifact's program bytes are decoded as UTF-8 JSON, rebuilt into the
    schema ``ExportedProgram`` dataclass, and handed to
    ``ExportedProgramDeserializer`` together with the artifact's state dict,
    constants and example inputs.
    """
    payload = artifact.exported_program
    assert isinstance(payload, bytes)

    program_dict = json.loads(payload.decode("utf-8"))
    serialized_program = _dict_to_dataclass(ExportedProgram, program_dict)

    deserializer = ExportedProgramDeserializer(expected_opset_version)
    return deserializer.deserialize(
        serialized_program,
        artifact.state_dict,
        artifact.constants,
        artifact.example_inputs,
    )
2452
2453
def _canonicalize_graph(
    sorted_inputs, sorted_outputs, graph
) -> Tuple[Graph, Dict[str, str]]:
    """Build a canonical form of ``graph`` and report the renaming applied.

    The canonicalization runs in five stages:

    1. Nodes are re-sorted topologically; ready nodes are emitted in order of
       ``(target, ranks of their inputs, original index)``.
    2. Every name defined by a graph input or a node output is renamed to
       ``_<n>`` in definition order, and all uses are rewritten.
    3. Node metadata is cleared.
    4. The tensor / sym-int / sym-bool value tables are sorted by name.
    5. Subgraphs passed as ``as_graph`` inputs are canonicalized recursively
       and renamed ``_g<k>``.

    The arguments, nodes and value tables reachable from ``graph`` are
    mutated in place.

    Args:
        sorted_inputs: Graph input ``Argument`` objects in their final order.
        sorted_outputs: Graph output ``Argument`` objects in their final order.
        graph: The serialized graph to canonicalize.

    Returns:
        A tuple ``(canonical_graph, name_table)`` where ``name_table`` maps
        each original value name to its new name.

    Raises:
        AssertionError: On an unknown argument kind, a duplicated definition,
            or a use of a name that is neither defined nor a graph input.
    """
    # Argument kinds whose payload may carry named graph values.
    named_kinds = frozenset(
        {
            "as_tensor",
            "as_tensors",
            "as_sym_int",
            "as_sym_ints",
            "as_sym_bool",
            "as_sym_bools",
            "as_optional_tensors",
        }
    )
    # Argument kinds that never reference a named graph value.
    unnamed_kinds = frozenset(
        {
            "as_none",
            "as_int",
            "as_ints",
            "as_float",
            "as_floats",
            "as_string",
            "as_strings",
            "as_scalar_type",
            "as_memory_format",
            "as_layout",
            "as_device",
            "as_bool",
            "as_bools",
            "as_graph",
            "as_custom_obj",
            "as_operator",
        }
    )

    def _get_argument(a: Argument):
        """Return the name-carrying payload of ``a``, or ``None`` if it has none."""
        kind = a.type
        if kind in named_kinds:
            return getattr(a, kind)
        if kind in unnamed_kinds:
            return None
        raise AssertionError(f"Unknown input type to the ExportedProgram: {a}")

    # Stage 1: Reorder named items.
    def for_args(f, a):
        """Apply ``f`` to every leaf of the name-carrying payload of ``a``."""
        assert isinstance(a, Argument)
        pytree.tree_map(f, _get_argument(a))

    def sort_nodes(nodes):
        """Return ``nodes`` in a deterministic topological order."""

        @dataclass
        class Edges:
            outs: List[int]  # indices of nodes that consume this node's outputs
            ins: int  # count of producer edges not yet satisfied

        graph_inputs: Set[str] = set()
        def_table: Dict[str, int] = {}  # value name -> index of defining node
        edges: Dict[int, Edges] = {}
        # Min-heap of ready nodes keyed by (target, input ranks, original index).
        candidates: List[Tuple[str, List[Tuple[str, List[int]]], int]] = []
        rank: Dict[str, int] = {}  # value name -> order in which it was defined
        ret: List[Node] = []

        def get_name(a) -> Optional[str]:
            if a is None:
                return None
            if isinstance(a, TensorArgument):
                return a.name
            elif isinstance(a, (SymIntArgument, SymBoolArgument)):
                if a.type == "as_name":
                    return a.as_name
                elif a.type in ("as_int", "as_bool"):
                    return None
                else:
                    raise AssertionError(f"Unknown argument type: {a}")
            elif isinstance(a, OptionalTensorArgument):
                if a.type == "as_tensor":
                    return a.as_tensor.name
                elif a.type == "as_none":
                    return None
                else:
                    raise AssertionError(f"Unknown optional tensor type: {a}")
            else:
                raise AssertionError(f"Unknown argument type: {a}")

        def add_input(a):
            if s := get_name(a):
                graph_inputs.add(s)

        for i in sorted_inputs:
            for_args(add_input, i)

        for idx, node in enumerate(nodes):

            def add_def(a):
                if s := get_name(a):
                    assert s not in def_table
                    def_table[s] = idx

            for o in node.outputs:
                for_args(add_def, o)

            edges[idx] = Edges([], 0)

        for idx, user in enumerate(nodes):

            def add_edge(a):
                if s := get_name(a):
                    if s not in def_table:
                        # Not produced by any node, so it must be a graph input.
                        assert s in graph_inputs
                        return
                    src = def_table[s]
                    edges[src].outs.append(idx)
                    edges[idx].ins += 1

            for i in user.inputs:
                for_args(add_edge, i.arg)

        def add_rank(a):
            if s := get_name(a):
                assert s not in rank
                rank[s] = len(rank)

        def get_rank(a):
            if s := get_name(a):
                return rank[s]
            else:
                return -1

        for i in sorted_inputs:
            for_args(add_rank, i)

        def add_candidate(idx: int):
            def get_ranks(i):
                ranks = []
                for_args(lambda x: ranks.append(get_rank(x)), i)
                return ranks

            node = nodes[idx]
            args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs]
            heapq.heappush(candidates, (node.target, args_rank, idx))

        for idx, e in edges.items():
            if e.ins == 0:
                add_candidate(idx)

        while len(candidates) > 0:
            _, _, idx = heapq.heappop(candidates)
            node = nodes[idx]
            for o in node.outputs:
                for_args(add_rank, o)
            ret.append(node)
            assert idx in edges
            for user in edges[idx].outs:
                e = edges[user]
                assert e.ins > 0
                e.ins -= 1
                if e.ins == 0:
                    add_candidate(user)
            edges[idx].outs.clear()

        return ret

    sorted_nodes = sort_nodes(graph.nodes)
    assert len(sorted_nodes) == len(graph.nodes)

    # Stage 2: Rename nodes.
    name_table: Dict[str, str] = {}

    def rename_def(a):
        def _rename(arg_name, values):
            new_name = f"_{len(name_table)}"
            assert arg_name not in name_table
            name_table[arg_name] = new_name
            assert arg_name in values
            # Move the value's entry in its table to the new key.
            values[new_name] = values.pop(arg_name)
            return new_name

        if a is None:
            return
        if isinstance(a, TensorArgument):
            a.name = _rename(a.name, graph.tensor_values)
        elif isinstance(a, SymIntArgument):
            if a.type == "as_name":
                a.as_name = _rename(a.as_name, graph.sym_int_values)
        elif isinstance(a, SymBoolArgument):
            if a.type == "as_name":
                a.as_name = _rename(a.as_name, graph.sym_bool_values)
        else:
            raise AssertionError(f"Unknown argument type: {a}")

    def replace_use(a):
        if a is None:
            return
        if isinstance(a, TensorArgument):
            a.name = name_table.get(a.name, a.name)
        elif isinstance(a, SymIntArgument):
            if a.type == "as_name":
                a.as_name = name_table.get(a.as_name, a.as_name)
        elif isinstance(a, SymBoolArgument):
            if a.type == "as_name":
                a.as_name = name_table.get(a.as_name, a.as_name)
        elif isinstance(a, OptionalTensorArgument):
            if a.type == "as_tensor":
                a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name)
        else:
            raise AssertionError(f"Unknown argument type: {a}")

    for i in sorted_inputs:
        for_args(rename_def, i)

    for n in sorted_nodes:
        for o in n.outputs:
            for_args(rename_def, o)

    for n in sorted_nodes:
        for i in n.inputs:
            for_args(replace_use, i.arg)

    for o in sorted_outputs:
        for_args(replace_use, o)

    # Stage 3: Remove unstable fields.
    for n in sorted_nodes:
        n.metadata.clear()

    # Stage 4: Aggregate values.
    sorted_tensor_values = dict(
        sorted(graph.tensor_values.items(), key=operator.itemgetter(0))
    )
    sorted_sym_int_values = dict(
        sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
    )
    sorted_sym_bool_values = dict(
        sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
    )

    # Stage 5: Recurse in subgraphs.
    counter = 0
    for node in sorted_nodes:
        for i in node.inputs:
            a = i.arg
            if a.type == "as_graph":
                subgraph = a.as_graph.graph
                # The recursive call returns (graph, name_table); only the
                # graph belongs in the argument, so drop the name table.
                a.as_graph.graph, _ = _canonicalize_graph(
                    subgraph.inputs, subgraph.outputs, subgraph
                )
                a.as_graph.name = f"_g{counter}"
                counter += 1

    graph = Graph(
        inputs=sorted_inputs,
        outputs=sorted_outputs,
        nodes=sorted_nodes,
        tensor_values=sorted_tensor_values,
        sym_int_values=sorted_sym_int_values,
        sym_bool_values=sorted_sym_bool_values,
        is_single_tensor_return=graph.is_single_tensor_return,
    )
    return graph, name_table
2722
2723
def canonicalize(ep: ExportedProgram) -> ExportedProgram:
    """
    Normalize a serialized ExportedProgram, so that different eager programs
    which share the same semantics get a single representation on disk.

    This function canonicalizes an ExportedProgram by:

    1. Sorting nodes in topological order.
    2. Rename nodes to have unique names.
    3. Remove unstable fields.
    4. Aggregate the above program fields.
    5. Recurse in subgraphs.

    Graph inputs and outputs are also reordered by spec kind (see
    ``rank_input`` / ``rank_output``), and the names stored in the signature
    specs are rewritten to match the renamed graph values.

    The input is deep-copied first, so the caller's object is not modified.

    Args:
        ep (ExportedProgram): The ExportedProgram to canonicalize. Note that
            inside this function the name ``ep`` refers to this argument, not
            to the module-level ``ep`` import alias.

    Returns:
        ExportedProgram: The canonicalized exported program.

    Raises:
        AssertionError: If the graph and signature disagree on the number of
            inputs or outputs, or an unknown spec / argument kind is found.
    """
    ep = copy.deepcopy(ep)

    opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))
    range_constraints = dict(
        sorted(ep.range_constraints.items(), key=operator.itemgetter(0))
    )
    module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn)
    signature = ep.graph_module.signature
    graph = ep.graph_module.graph

    assert len(graph.inputs) == len(signature.input_specs)
    assert len(graph.outputs) == len(signature.output_specs)

    def rank_input(inp) -> Tuple[int, Optional[str], int]:
        """Sort key for inputs: (kind priority, name if any, original index)."""
        idx, (_, spec) = inp
        assert isinstance(spec, InputSpec)
        if spec.type == "user_input":
            return 5, None, idx
        elif spec.type == "parameter":
            return 1, spec.parameter.parameter_name, idx
        elif spec.type == "buffer":
            return 2, spec.buffer.buffer_name, idx
        elif spec.type == "tensor_constant":
            return 3, spec.tensor_constant.tensor_constant_name, idx
        elif spec.type == "custom_obj":
            return 4, spec.custom_obj.custom_obj_name, idx
        elif spec.type == "token":
            return 0, None, idx
        elif spec.type == "constant_input":
            return 6, spec.constant_input.name, idx
        else:
            raise AssertionError(f"Unknown input type: {spec}")

    def rank_output(out) -> Tuple[int, Optional[str], int]:
        """Sort key for outputs: (kind priority, name if any, original index)."""
        idx, (_, spec) = out
        assert isinstance(spec, OutputSpec)
        if spec.type == "user_output":
            return 3, None, idx
        elif spec.type == "loss_output":
            return 3, None, idx
        elif spec.type == "buffer_mutation":
            return 1, spec.buffer_mutation.buffer_name, idx
        elif spec.type == "gradient_to_parameter":
            return 4, spec.gradient_to_parameter.parameter_name, idx
        elif spec.type == "gradient_to_user_input":
            return 5, None, idx
        elif spec.type == "user_input_mutation":
            return 2, None, idx
        elif spec.type == "token":
            return 0, None, idx
        else:
            raise AssertionError(f"Unknown output type: {spec}")

    sorted_ins = sorted(
        enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input
    )
    # zip(*()) yields nothing, so unpacking it would fail on an empty list.
    if sorted_ins:
        sorted_inputs, input_specs = zip(*(pair for _, pair in sorted_ins))  # type: ignore[assignment]
    else:
        sorted_inputs, input_specs = (), ()

    sorted_outs = sorted(
        enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output
    )
    if sorted_outs:
        sorted_outputs, output_specs = zip(*(pair for _, pair in sorted_outs))  # type: ignore[assignment]
    else:
        sorted_outputs, output_specs = (), ()

    sorted_graph, replace_table = _canonicalize_graph(
        sorted_inputs, sorted_outputs, graph
    )

    def replace_input(spec):
        """Rewrite the value names stored in one input spec, in place."""
        assert isinstance(spec, InputSpec)
        if spec.type == "user_input":
            arg = spec.user_input.arg
            if arg.type == "as_tensor":
                t = arg.as_tensor
                t.name = replace_table[t.name]
            elif arg.type == "as_sym_int":
                s = arg.as_sym_int
                if s.type == "as_name":
                    s.as_name = replace_table[s.as_name]
                elif s.type == "as_int":
                    pass
                else:
                    raise AssertionError(f"Unknown sym_int type: {s}")
            elif arg.type in (
                "as_none",
                "as_bool",
                "as_int",
                "as_float",
                "as_string",
                "as_custom_obj",
            ):
                return
            else:
                raise AssertionError(f"Unknown input type: {arg}")
        elif spec.type == "parameter":
            t = spec.parameter.arg
            t.name = replace_table[t.name]
        elif spec.type == "buffer":
            t = spec.buffer.arg
            t.name = replace_table[t.name]
        elif spec.type == "tensor_constant":
            t = spec.tensor_constant.arg
            t.name = replace_table[t.name]
        elif spec.type == "custom_obj":
            return
        elif spec.type == "token":
            tok = spec.token.arg
            tok.name = replace_table[tok.name]
        elif spec.type == "constant_input":
            return
        else:
            raise AssertionError(f"Unknown input type: {spec}")

    def replace_output(spec):
        """Rewrite the value names stored in one output spec, in place."""
        assert isinstance(spec, OutputSpec)
        if spec.type == "user_output":
            arg = spec.user_output.arg
            if arg.type == "as_tensor":
                t = arg.as_tensor
                t.name = replace_table[t.name]
            elif arg.type == "as_sym_int":
                s = arg.as_sym_int
                if s.type == "as_name":
                    s.as_name = replace_table[s.as_name]
                elif s.type == "as_int":
                    pass
                else:
                    raise AssertionError(f"Unknown sym_int type: {s}")
            elif arg.type in ("as_none", "as_int", "as_float", "as_string"):
                return
            else:
                raise AssertionError(f"Unknown input type: {arg}")
        elif spec.type == "loss_output":
            t = spec.loss_output.arg
            t.name = replace_table[t.name]
        elif spec.type == "buffer_mutation":
            t = spec.buffer_mutation.arg
            t.name = replace_table[t.name]
        elif spec.type == "gradient_to_parameter":
            t = spec.gradient_to_parameter.arg
            t.name = replace_table[t.name]
        elif spec.type == "gradient_to_user_input":
            g = spec.gradient_to_user_input
            g.arg.name = replace_table[g.arg.name]
            g.user_input_name = replace_table[g.user_input_name]
        elif spec.type == "user_input_mutation":
            u = spec.user_input_mutation
            u.arg.name = replace_table[u.arg.name]
            u.user_input_name = replace_table[u.user_input_name]
        elif spec.type == "token":
            tok = spec.token.arg
            tok.name = replace_table[tok.name]
        else:
            raise AssertionError(f"Unknown output type: {spec}")

    for input_spec in input_specs:
        replace_input(input_spec)

    for output_spec in output_specs:
        replace_output(output_spec)

    return ExportedProgram(
        graph_module=GraphModule(
            graph=sorted_graph,
            signature=GraphSignature(
                input_specs=list(input_specs),
                output_specs=list(output_specs),
            ),
            module_call_graph=module_call_graph,
        ),
        opset_version=opset_version,
        range_constraints=range_constraints,
        schema_version=ep.schema_version,
        dialect=ep.dialect,
    )
2917