xref: /aosp_15_r20/external/pytorch/torch/_export/converter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import builtins
3import logging
4import operator
5import typing
6import warnings
7from contextlib import contextmanager
8from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
9
10import torch
11import torch.export._trace
12from torch import _C
13from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import (
14    replace_quantized_ops_with_standard_ops,
15)
16from torch.export.exported_program import ExportedProgram
17from torch.export.graph_signature import (
18    ConstantArgument,
19    CustomObjArgument,
20    InputKind,
21    InputSpec,
22    OutputKind,
23    OutputSpec,
24    TensorArgument,
25)
26from torch.fx import subgraph_rewriter
27
28
29log = logging.getLogger(__name__)
30
31
32def _get_param_count_list(method_graph, args_params):
33    param_count_list = []
34    for input_, arg_params_ in zip(method_graph.inputs(), args_params):
35        if "PackedParams" in str(input_.type()):
36            in_vars, _ = torch.jit._flatten(arg_params_)
37            param_count_list.append(len(in_vars))
38        else:
39            param_count_list.append(arg_params_ is not None)
40
41    return param_count_list
42
43
44def _trace_and_get_graph_from_model(model, args):
45    # A basic sanity check: make sure the state_dict keys are the same
46    # before and after running the model.  Fail fast!
47    orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
48
49    # Disable Autocast cache because it replaces kernel's weight and bias
50    # by (undesired) constants.
51    # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
52    prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
53    torch.set_autocast_cache_enabled(False)
54    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
55        model,
56        args,
57        strict=False,
58        _force_outplace=False,
59        _return_inputs_states=True,
60    )
61    torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
62
63    if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
64        raise RuntimeError(
65            "state_dict changed after running the tracer; "
66            "something weird is happening in your model!"
67        )
68
69    return trace_graph, torch_out
70
71
72def _create_jit_graph(
73    model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
74) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
75    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
76        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
77        torch_out = None
78
79        if isinstance(model, torch.jit.ScriptModule):
80            try:
81                graph = model.forward.graph  # type: ignore[attr-defined]
82            except AttributeError as e:
83                raise RuntimeError("'forward' method must be a script method") from e
84            _C._jit_pass_onnx_function_substitution(graph)
85            freezed_module = _C._freeze_module(
86                typing.cast(_C.ScriptModule, model._c), preserveParameters=True
87            )
88            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
89            method_graph = module._get_method("forward").graph
90            args_params = tuple(args) + tuple(params)
91            param_count_list = _get_param_count_list(method_graph, args_params)
92            in_vars, _ = torch.jit._flatten(args_params)
93            graph = _C._propagate_and_assign_input_shapes(
94                method_graph, tuple(in_vars), param_count_list, False, False
95            )
96            return graph, params, torch_out, module
97
98        # torch.jit.ScriptFunction
99        params = []
100        graph = model.graph
101        _C._jit_pass_onnx_function_substitution(graph)
102        param_count_list = _get_param_count_list(graph, args)
103        graph = _C._propagate_and_assign_input_shapes(
104            graph, flattened_args, param_count_list, False, False
105        )
106        return graph, params, torch_out, None
107
108    graph, torch_out = _trace_and_get_graph_from_model(model, args)
109    _C._jit_pass_onnx_lint(graph)
110    state_dict = torch.jit._unique_state_dict(model)
111    params = list(state_dict.values())
112    graph_inputs = list(graph.inputs())
113    user_input_num = len(graph_inputs) - len(state_dict)
114    param_names = list(state_dict.keys())
115    for i, inp in enumerate(graph_inputs):
116        if i >= user_input_num:
117            inp.setDebugName(param_names[i - user_input_num])
118    _C._jit_pass_onnx_function_substitution(graph)
119    return graph, params, torch_out, None
120
121
122def list_add(a, b):
123    return a + b
124
125
126def list_append(container, element):
127    return container + [element]
128
129
130def execute_subgraph_from_prim_loop(
131    subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
132):
133    """
134    subgraph: GraphModule from sub-block.
135    iter_idx: The index of interation.
136    len_loop_local_arguments: The number of loop local arguments in args.
137    """
138
139    # Loop local variables. TS graph create those as inputs because their values
140    # are updated inside the loop.
141    loop_local_args = args[:len_loop_local_arguments]
142    # Global variables that are not passed in as inputs to the loop sub-blocks
143    # but are directly used. Most of time, their values are not updated, but
144    # the only exception is when there are some operations that perform inplace
145    # updates.
146    global_args = args[len_loop_local_arguments:]
147    return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs)
148
149
150def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
151    def pattern(im, dim, scale):
152        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
153        scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
154        div_scalar_mode = torch.ops.aten.div.Scalar_mode(
155            scalar_tensor, scale, rounding_mode="trunc"
156        )
157        int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
158        return int_tensor
159
160    def replacement(im, dim, scale):
161        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
162        return sym_size_int // scale
163
164    replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
165
166
167def is_valid_for_codegen(name):
168    if len(name) == 0:
169        raise RuntimeError("Empty argument name for codegen")
170    if name[0].isdigit():
171        return False
172    return True
173
174
175def normalize_name(name: str, prefix: str = "rename") -> str:
176    name = name.replace(".", "_")
177    if is_valid_for_codegen(name):
178        return name
179    return f"{prefix}_{name}"
180
181
182def ir_name_to_func_name(name: str) -> str:
183    """prim::If -> convert_prim_If"""
184    name_list = name.split("::")
185    return "convert_" + "_".join(name_list)
186
187
188def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph):
189    if is_top_level_graph:
190        return fx_graph.get_attr(name)
191    return fx_graph.placeholder(name)
192
193
194_TORCH_DTYPE_TO_ENUM = {
195    torch.uint8: 0,
196    torch.int8: 1,
197    torch.int16: 2,
198    torch.int32: 3,
199    torch.int64: 4,
200    torch.float16: 5,
201    torch.float32: 6,
202    torch.float64: 7,
203    torch.complex32: 8,
204    torch.complex64: 9,
205    torch.complex128: 10,
206    torch.bool: 11,
207    torch.qint8: 12,
208    torch.quint8: 13,
209    torch.bfloat16: 15,
210}
211
212_TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()}
213
214
215def get_dtype_as_int(tensor):
216    """
217    prim::dtype has the signature "Tensor a) -> int", where it gets the dtype of
218    the tensor and returns the integer corresponding to this dtype based on the
219    enum in ScalarType.h
220    """
221    dtype = tensor.dtype
222    if dtype not in _TORCH_DTYPE_TO_ENUM:
223        raise RuntimeError(f"Unsupported dtype {dtype}")
224    return _TORCH_DTYPE_TO_ENUM[dtype]
225
226
227# Those operators will be automatically populated to a instance method
228# of TS2FXGraphConverter with name convert_<namespace>_<opname>().
229# Please check __init__ for method population implementations.
230kind_to_standard_operators = {
231    "prim::max": builtins.max,
232    "prim::min": builtins.min,
233    "prim::TupleIndex": operator.getitem,
234    "aten::__is__": operator.is_,
235    "aten::__isnot__": operator.is_not,
236    "aten::__not__": operator.not_,
237    "aten::__contains__": operator.contains,
238    "prim::dtype": get_dtype_as_int,
239    "aten::len": len,
240    # Mapping from specialized op to its symbolic counterpart.
241    # They currently do not have any other overrides.
242    "aten::numel": torch.ops.aten.sym_numel,
243    "aten::size": torch.ops.aten.sym_size,
244    "aten::storage_offset": torch.ops.aten.sym_storage_offset,
245    "aten::stride": torch.ops.aten.sym_stride,
246}
247
248
249def get_ir_value_parent_name_and_attr_name(node):
250    irv_parent_name, irv_name = node.input().debugName(), node.output().debugName()
251    attr_name = node.s("name")
252    return irv_name, irv_parent_name, attr_name
253
254
255def construct_fqn(ir, ref_map, name_map):
256    name_list = []
257    while ir in ref_map:
258        name_list.append(name_map[ir])
259        ir = ref_map[ir]
260    return ".".join(reversed(name_list))
261
262
263def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]:
264    """
265    Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
266    When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
267    each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model
268    parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model,
269    we will run this pass which will:
270        1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls.
271        2. Process the graph bottom up to find the lifted attributes of each block by taking the union
272        of the attributes used in the current block, and the lifted attributes of all its child blocks.
273
274    Returns:
275        A mapping of blocks to a set of FQNs of its lifted attributes.
276    """
277
278    # A map from a block to its expected to be lifted arguments.
279    blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {}
280
281    # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
282    # GetAttr node. By traversing this reference map, we can figure out the
283    # full IR aliasing pass and figure out the FQN of an attribute.
284    # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
285    node_to_parent_map: Dict[str, str] = {}
286
287    # Used for reconstructing the FQN of an attribute based on the reference map.
288    # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
289    # This name map stores which attribute name is called for a src IR --> dest IR action.
290    # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
291    node_to_attr_name: Dict[str, str] = {}
292
293    def _dfs_get_attr_dependency(entry):
294        """
295        First DFS path to construct reference map and name map.
296        """
297        for node in entry.nodes():
298            if node.kind() == "prim::GetAttr":
299                (
300                    irv_name,
301                    irv_parent_name,
302                    attr_name,
303                ) = get_ir_value_parent_name_and_attr_name(node)
304                node_to_parent_map[irv_name] = irv_parent_name
305                node_to_attr_name[irv_name] = attr_name
306            for block in node.blocks():
307                _dfs_get_attr_dependency(block)
308
309    def _map_blocks_to_lifted_attrs(entry):
310        """
311        Walk the graph in a bottom-up fashion to build the expected to be
312        lifted arguments for each block.
313        """
314        arguments: Set[str] = set()
315        for node in entry.nodes():
316            for block in node.blocks():
317                # Recursively build.
318                arguments = arguments.union(_map_blocks_to_lifted_attrs(block))
319            if node.kind() == "prim::GetAttr":
320                irv_name = node.output().debugName()
321                # Skip for intermediate GetAttr, which will anyway not result a FQN.
322                # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"}
323                #       node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"}
324                #       There is only one FQN %3-->%2-->%1: self.linear.weight
325                #       %2-->%1 is not a FQN: self.linear
326                if irv_name not in set(node_to_parent_map.values()):
327                    arguments.add(
328                        construct_fqn(irv_name, node_to_parent_map, node_to_attr_name)
329                    )
330        if not isinstance(entry, torch._C.Graph):  # Skip the top level.
331            blocks_to_lifted_attrs[entry] = arguments
332        return arguments
333
334    _dfs_get_attr_dependency(graph)
335    _map_blocks_to_lifted_attrs(graph)
336
337    return blocks_to_lifted_attrs
338
339
340def get_attribute_fqn_from_ts_node(
341    name_to_attribute_fqn: Dict[str, str], node: torch._C.Node
342) -> str:
343    def get_attr(name: str):
344        if name in name_to_attribute_fqn:
345            return name_to_attribute_fqn[name]
346        else:
347            raise ValueError(f"Attribute {name} not found")
348
349    if node.kind() == "prim::SetAttr":
350        input_name = next(node.inputs()).debugName()
351    elif node.kind() == "prim::GetAttr":
352        input_name = node.input().debugName()
353    else:
354        raise RuntimeError(
355            f"Unexpected node kind when getting attribute fqn. node: {node} "
356        )
357
358    attr_name = node.s("name")
359    root_attr_name = get_attr(input_name)
360    attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
361
362    return attr_fqn
363
364
365def get_op_overload(node: torch._C.Node):
366    schema_str = node.schema()
367    assert schema_str != "(no schema)", f"got empty schema for {node}"
368    schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str)
369    ns, op_name = str(schema.name).split("::")
370    override = schema.overload_name
371
372    try:
373        op_overload_mod = getattr(torch.ops, ns)
374        op_overload_packet = getattr(op_overload_mod, op_name)
375        if override:
376            op_overload = getattr(op_overload_packet, override)
377        else:
378            op_overload = op_overload_packet.default
379    except Exception as e:
380        raise RuntimeError(
381            f"Unable to find operator {node.kind()} with schema {node.schema()}"
382        ) from e
383
384    return op_overload
385
386
387class TS2FXGraphConverter:
388    def __init__(
389        self,
390        ts_graph: Union[torch._C.Graph, torch._C.Block],
391        name_to_param: Dict[str, torch.Tensor],
392        name_to_buffer: Dict[str, torch.Tensor],
393        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
394        name_to_non_tensor_attribute: Dict[str, Any],
395        name_to_constant: Dict[str, Any],
396    ):
397        self.ts_graph = ts_graph
398        self.name_to_param = name_to_param
399        self.name_to_buffer = name_to_buffer
400
401        self.fx_graph: torch.fx.Graph = torch.fx.Graph()
402        self.input_specs: List[InputSpec] = []
403        self.output_specs: List[OutputSpec] = []
404
405        self.name_to_node: Dict[
406            str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
407        ] = {}
408        self.name_to_constant: Dict[str, Any] = name_to_constant
409
410        # Mapping from torchscript node output name to attribute fully qualified name
411        self.name_to_attribute_fqn: Dict[str, str] = {}
412
413        # Mapping from fully qualified name to real values or a fx graph node
414        # During convert, this represents the current value of a non-tensor attribute
415        # One use case is:
416        #   def forward(self, x):
417        #        c1 = self.count
418        #        self.count += 1
419        #        c2 = self.count
420        #        return x + c1 + c2
421        self.name_to_non_tensor_attribute_node: Dict[str, Any] = {}
422
423        # Mapping from fully qualified name to initial real values inputs
424        # We separate it from self.name_to_non_tensor_attribute_node since
425        # we need initial real value input when we construct fx.GraphModule
426        self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute
427
428        self.subgraphs: Dict[str, torch.fx.GraphModule] = {}
429
430        self.blocks_to_lifted_attrs = blocks_to_lifted_attrs
431
432        # Populate methods for the standard operators.
433        for k in kind_to_standard_operators.keys():
434            handler_func_name = ir_name_to_func_name(k)
435            # Create an indirect function call:
436            # convert_<namespace>_<opname> --> lambda node: _convert_standard_operator(node)
437            setattr(
438                self,
439                handler_func_name,
440                lambda node: self._convert_standard_operators(node),
441            )
442
443        # This stores a list of return results that do not appear in the original TS
444        # graph's outputs. The reason we maintain this is because some operations in the sub-block
445        # might have inplace updates to the variable defined in the parent fx graph. After
446        # the execution of that sub-block, the variable defined in the parent fx graph also
447        # needs to be updated.
448        self.name_update_from_subblock_to_parent: Set[str] = set()
449
450    def _is_get_attr_node(self, fqn):
451        return (
452            fqn in self.name_to_buffer
453            or fqn in self.name_to_param
454            or (
455                fqn in self.name_to_constant
456                and isinstance(self.name_to_constant[fqn], torch.ScriptObject)
457            )
458        )
459
460    def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
461        subgraph_nodes, subgraph_converters = [], []
462        for block in node.blocks():
463            subgraph_converter = TS2FXGraphConverter(
464                block,
465                self.name_to_param,
466                self.name_to_buffer,
467                self.blocks_to_lifted_attrs,
468                {},
469                self.name_to_constant,
470            )
471            subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
472
473            for block_arg in arguments:
474                normalized_block_arg_name = normalize_name(block_arg)
475                placeholder_node = subgraph_converter.fx_graph.placeholder(
476                    normalized_block_arg_name
477                )
478                subgraph_converter.name_to_node[block_arg] = placeholder_node
479
480            subgraph = subgraph_converter.convert()
481            subgraph_name = self.add_subgraph(subgraph)
482            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
483            subgraph_converters.append(subgraph_converter)
484        return subgraph_nodes, subgraph_converters
485
486    def _identify_inputs_as_arguments(self, entry):
487        """
488        Identify inputs from the innermost sub-block. This is needed
489        for nested sub-blocks when the input is hidden in the nested sub-block.
490        E.g., example IR of input is hidden in the nested sub-block.
491        Graph[x.1]
492        %1 = ...
493            Block[]
494                Block[x.1]
495                    %2 = x.1 ...
496        """
497        arguments: Set[str] = set()
498        for block in entry.blocks():
499            for block_node in block.nodes():
500                for block_node_in in block_node.inputs():
501                    if (
502                        block_node_in.debugName() in self.name_to_node
503                        and block_node_in.debugName() not in self.name_to_attribute_fqn
504                    ):
505                        arguments.add(block_node_in.debugName())
506                arguments = arguments.union(
507                    self._identify_inputs_as_arguments(block_node)
508                )
509        return arguments
510
511    def is_top_level_graph(self):
512        return isinstance(self.ts_graph, torch._C.Graph)
513
514    def add_subgraph(self, subgraph) -> str:
515        name = f"subgraph_{len(self.subgraphs)}"
516        self.subgraphs[name] = subgraph
517        return name
518
519    def get_args_kwargs(self, node: torch._C.Node, schema):
520        args = []
521        kwargs = {}
522        for input, schema_arg in zip(node.inputs(), schema.arguments):
523            if schema_arg.kwarg_only:
524                kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input)
525            else:
526                args.append(self.get_fx_value_by_ir_value(input))
527
528        return tuple(args), kwargs
529
530    def get_fx_value_by_ir_value(self, value: torch._C.Value):
531        value_name = value.debugName()
532
533        if value_name in self.name_to_node:
534            input_node = self.name_to_node[value_name]
535            return input_node
536        elif value_name in self.name_to_constant:
537            if isinstance(self.name_to_constant[value_name], torch.ScriptObject):
538                return self.fx_graph.get_attr(value_name)
539            return self.name_to_constant[value_name]
540        else:
541            raise ValueError(f"Input {value_name} not found")
542
543    def get_fx_value_by_fqn(self, name):
544        if name in self.name_to_node:
545            fx_node = self.name_to_node[name]
546        elif name in self.name_to_constant:
547            fx_node = self.name_to_constant[name]
548        elif name in self.name_to_non_tensor_attribute_node:
549            fx_node = self.name_to_non_tensor_attribute_node[name]
550        elif name in self.name_to_non_tensor_attribute:
551            fx_node = self.name_to_non_tensor_attribute[name]
552        else:
553            raise ValueError(f"Attribute {name} not found")
554        return fx_node
555
556    def convert(self) -> torch.fx.GraphModule:
557        self.convert_graph_inputs()
558
559        for node in self.ts_graph.nodes():
560            self.convert_node(node)
561
562        self.convert_graph_outputs()
563
564        # Pass parameter and buffer to the root for lookup.
565        gm = torch.fx.GraphModule(
566            {
567                **self.subgraphs,
568                **self.name_to_param,
569                **self.name_to_buffer,
570                **self.name_to_non_tensor_attribute,
571                **self.name_to_constant,
572            },
573            self.fx_graph,
574        )
575
576        inplace_optimize_sym_size_div(gm)
577
578        gm.graph.lint()
579
580        return gm
581
582    def convert_graph_inputs(self):
583        for graph_input in self.ts_graph.inputs():
584            name = graph_input.debugName()
585
586            if name in self.name_to_param:
587                normalized_name = normalize_name(name)
588                self.input_specs.append(
589                    InputSpec(
590                        InputKind.PARAMETER,
591                        arg=TensorArgument(name=normalized_name),
592                        target=name,
593                    )
594                )
595                fx_node = get_node_as_placeholder_or_get_attr(
596                    self.fx_graph, name, self.is_top_level_graph()
597                )
598            elif name in self.name_to_buffer:
599                normalized_name = normalize_name(name)
600                self.input_specs.append(
601                    InputSpec(
602                        InputKind.BUFFER,
603                        arg=TensorArgument(name=normalized_name),
604                        target=name,
605                        persistent=True,
606                    )
607                )
608                fx_node = get_node_as_placeholder_or_get_attr(
609                    self.fx_graph, name, self.is_top_level_graph()
610                )
611            elif name in self.name_to_constant:
612                assert isinstance(
613                    self.name_to_constant[name], torch.ScriptObject
614                ), "Input conversion only handles ScriptObject"
615                normalized_name = normalize_name(name)
616                self.input_specs.append(
617                    InputSpec(
618                        InputKind.CUSTOM_OBJ,
619                        arg=CustomObjArgument(
620                            name=normalized_name, class_fqn=normalized_name
621                        ),
622                        target=name,
623                        persistent=False,
624                    )
625                )
626                fx_node = get_node_as_placeholder_or_get_attr(
627                    self.fx_graph, name, self.is_top_level_graph()
628                )
629            elif isinstance(graph_input.type(), torch.ClassType):
630                # Directly skip inputs that are ScriptObject but not used in the graph.
631                continue
632            else:
633                normalized_name = normalize_name(name, prefix="input")
634                self.input_specs.append(
635                    InputSpec(
636                        InputKind.USER_INPUT,
637                        arg=TensorArgument(name=normalized_name),
638                        target=name,
639                    )
640                )
641                fx_node = self.fx_graph.placeholder(normalized_name)
642
643            self.name_to_node[name] = fx_node
644
645    def convert_aten_Float(self, node: torch._C.Node):
646        def to_float_tensor(t):
647            return t.to(dtype=torch.float).item()
648
649        inp_list = [
650            self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
651        ]  # noqa: C416
652        fx_node = self.fx_graph.call_function(
653            to_float_tensor,
654            tuple(inp_list),
655        )
656        self.name_to_node[node.output().debugName()] = fx_node
657
658    def convert_aten_tensor(self, node: torch._C.Node):
659        """aten::tensor creates a constant tensor ad-hoc --> GetAttr"""
660        args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema)
661
662        for k in kwargs:
663            if k == "requires_grad":
664                kwargs[k] = bool(kwargs[k])  # 0 -> False, 1 -> True
665
666        to_tensor = (
667            torch.tensor
668            if all(isinstance(a, int) for a in args)
669            else torch._refs.tensor
670        )
671
672        def target(*args, **kwargs):
673            if "dtype" in kwargs and kwargs["dtype"] is not None:
674                kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
675            return to_tensor(*args, **kwargs)
676
677        # def to_dynamic_tensor(*args, **kwargs):
678        #     if "dtype" in kwargs and kwargs["dtype"] is not None:
679        #         kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
680        #     return torch._refs.tensor(*args, **kwargs)
681
682        output_name = node.output().debugName()
683        fx_node = self.fx_graph.call_function(target, args, kwargs)
684        self.name_to_node[output_name] = fx_node
685
686    def convert_aten_append(self, node: torch._C.Node):
687        # special handle python list append: "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)"
688
689        # inplace append to the list!! This is kinda crazy, as we are inplace mutating the list
690        # This makes the converter "non-functional", and the result depends on the order of the nodes being converter
691        # In a sense, the converter now becomes an stateful interpreter
692        warnings.warn(
693            "Converting aten::append.t, which is a inplace mutation of the list. "
694            "This makes the converter non-functional: the result depends on the order of the append nodes being converter!"
695        )
696
697        args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
698        fx_node = self.fx_graph.call_function(list_append, args)
699        self.name_to_node[node.output().debugName()] = fx_node
700
701        # inplace mutate arg[0], which is the python list
702        self.name_to_node[node.inputsAt(0).debugName()] = fx_node
703
704        # Variables that need to be updated to parent module.
705        if not self.is_top_level_graph() and args[0].op == "placeholder":
706            self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName())
707
708    def convert_prim_Constant(self, node: torch._C.Node):
709        name = node.output().debugName()
710
711        value: Any = None
712        if node.hasAttribute("value"):
713            constant_kind = node.kindOf("value")
714            if constant_kind == "i":
715                value = node.i("value")
716            elif constant_kind == "f":
717                value = node.f("value")
718            elif constant_kind == "s":
719                value = node.s("value")
720            elif constant_kind == "t":
721                alias_name = (
722                    f"lifted_tensor_{name}"  # Follow naming convention from EP tracing.
723                )
724                fx_node = self.fx_graph.get_attr(alias_name)
725                self.name_to_node[name] = fx_node
726                name, value = alias_name, node.t("value")
727            elif constant_kind == "ival":
728                value = node.ival("value")
729            else:
730                raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
731        else:
732            value = None
733
734        self.name_to_constant[name] = value
735
736    def convert_prim_CallMethod(self, node: torch._C.Node):
737        inp_list = [
738            self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
739        ]  # noqa: C416
740        fx_node = self.fx_graph.call_method(
741            node.s("name"),
742            tuple(inp_list),
743        )
744        self.name_to_node[node.output().debugName()] = fx_node
745
746    def convert_prim_device(self, node: torch._C.Node):
747        input_type = node.input().type()
748        if input_type.isSubtypeOf(torch._C.TensorType.get()):
749            device = input_type.device()  # type: ignore[attr-defined]
750            output_name = node.output().debugName()
751            self.name_to_constant[output_name] = device
752        else:
753            raise ValueError(f"Unsupported JitType ({input_type}) when get device")
754
755    def convert_prim_GetAttr(self, node: torch._C.Node):
756        # Build fully qulified name
757        attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
758        output_name = node.output().debugName()
759        self.name_to_attribute_fqn[output_name] = attr_fqn
760
761        if self.is_top_level_graph():
762            if self._is_get_attr_node(attr_fqn):
763                # We insert a get_attr node due to two reasons.
764                # First, ts graph does not lift tensor constants as input nodes. So
765                # tensor constants may be ignored by in convert_graph_inputs().
766                # Second, attr_fqn may have been written to via SetAttr. Two
767                # GetAttr may give different values.
768                self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn)
769            else:
770                if attr_fqn not in self.name_to_non_tensor_attribute_node:
771                    self.name_to_non_tensor_attribute_node[
772                        attr_fqn
773                    ] = self.name_to_non_tensor_attribute[attr_fqn]
774                self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[
775                    attr_fqn
776                ]
777        else:
778            # Special support for if blocks which do not allow SetAttr TorchScript
779            # node and get_attr FX Graph Node.
780            if self._is_get_attr_node(attr_fqn):
781                self.name_to_node[output_name] = self.name_to_node[attr_fqn]
782
783    def convert_prim_SetAttr(self, node: torch._C.Node):
784        attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
785        attr_value = tuple(node.inputs())[1]
786        ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value)
787        if self._is_get_attr_node(attr_fqn):
788            fx_attr_node = self.fx_graph.get_attr(attr_fqn)
789            self.fx_graph.call_function(
790                torch.Tensor.copy_, (fx_attr_node, ts_graph_tensor_input)
791            )
792        else:
793            self.name_to_non_tensor_attribute_node[attr_fqn] = ts_graph_tensor_input
794
795    def convert_call_function_op(self, node: torch._C.Node):
796        target = get_op_overload(node)
797
798        args, kwargs = self.get_args_kwargs(node, target._schema)
799
800        fx_node = self.fx_graph.call_function(target, args, kwargs)
801
802        # TODO: covnert sourceRange() into stack_trace
803        # fx_node.meta["stack_trace"] = node.sourceRange()
804
805        if node.outputsSize() == 1:
806            output_name = node.output().debugName()
807            self.name_to_node[output_name] = fx_node
808        else:
809            for i, outp in enumerate(node.outputs()):
810                output_name = outp.debugName()
811                next_fx_node = self.fx_graph.call_function(
812                    operator.getitem, (fx_node, i)
813                )
814                self.name_to_node[output_name] = next_fx_node
815
816    def convert_prim_TupleConstruct(self, node: torch._C.Node):
817        self._convert_prim_iterator(node)
818
819    def convert_prim_ListConstruct(self, node: torch._C.Node):
820        self._convert_prim_iterator(node)
821
822    def _convert_prim_iterator(self, node: torch._C.Node):
823        output_list = []
824        for inp in node.inputs():
825            output_list.append(self.get_fx_value_by_ir_value(inp))
826
827        output_name = node.output().debugName()
828        self.name_to_node[output_name] = output_list
829
830    def convert_prim_DictConstruct(self, node: torch._C.Node):
831        output_dict = {}
832        k, v = None, None
833        for i, inp in enumerate(node.inputs()):
834            # We assume key value are stored in pair in the DictConstruct.
835            # The first element is the key and the following is the value.
836            if i % 2 == 0:
837                k = self.get_fx_value_by_ir_value(inp)
838            else:
839                v = self.get_fx_value_by_ir_value(inp)
840                assert (
841                    k is not None and v is not None
842                ), "DictConstruct has an empty key value pair."
843                output_dict[k] = v
844                k, v = None, None
845
846        assert (
847            k is None and v is None
848        ), "DictConstruct has an odd number of elements (violating our assumption)."
849
850        output_name = node.output().debugName()
851        self.name_to_node[output_name] = output_dict
852
853    def convert_prim_ListUnpack(self, node: torch._C.Node):
854        self._convert_prim_unpack_iterator(node)
855
856    def convert_prim_TupleUnpack(self, node: torch._C.Node):
857        self._convert_prim_unpack_iterator(node)
858
859    def _convert_prim_unpack_iterator(self, node: torch._C.Node):
860        # Single input and multiple outputs for unpacking.
861        for i, outp in enumerate(node.outputs()):
862            outp_name = outp.debugName()
863            inp = self.get_fx_value_by_ir_value(node.input())
864            fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
865            self.name_to_node[outp_name] = fx_node
866
867    def convert_aten_Int(self, node: torch._C.Node):
868        # converts aten::Int as aten._to_copy + aten::_local_scalar_dense
869        target = torch.ops.aten._to_copy.default
870        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
871        to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})
872
873        fx_node = self.fx_graph.call_function(
874            torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
875        )
876
877        # TODO: covnert sourceRange() into stack_trace
878        # fx_node.meta["stack_trace"] = node.sourceRange()
879
880        output_name = node.output().debugName()
881        self.name_to_node[output_name] = fx_node
882
883    def convert_prim_NumToTensor(self, node: torch._C.Node):
884        # Converts prim::NumToTensor as aten.scalar_tensor.
885        # prim::NumToTensor IRs are currently triggered by:
886        # .size() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950
887        # .numel() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971
888        # For both of those APIs, torch.jit.trace implicitly sets the output tensor type
889        # to be LongTensor.
890        target = torch.ops.aten.scalar_tensor
891        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
892
893        fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long})
894        output_name = node.output().debugName()
895        self.name_to_node[output_name] = fx_node
896
897    def convert_prim_CreateObject(self, node: torch._C.Node):
898        output_name = node.output().debugName()
899        self.name_to_attribute_fqn[output_name] = ""
900
901    def convert_aten__convolution(self, node: torch._C.Node):
902        # converts aten::_convolution as aten.convolution, since aten::_convolution
903        # doesn't have a meta function
904        target = torch.ops.aten.convolution.default
905        args, kwargs = self.get_args_kwargs(node, target._schema)
906
907        fx_node = self.fx_graph.call_function(target, args, kwargs)
908
909        output_name = node.output().debugName()
910        self.name_to_node[output_name] = fx_node
911
912    def convert_aten_div(self, node: torch._C.Node):
913        target = get_op_overload(node)
914        schema = target._schema
915
916        args, kwargs = self.get_args_kwargs(node, schema)
917
918        # converts aten::div.Tensor_mode(x, tensor_constant)
919        # as aten.div.Scalar_mode(x, tensor_constant.item())
920        if schema.overload_name == "Tensor_mode":
921            arg1_name = args[1].name
922            if arg1_name in self.name_to_constant and isinstance(
923                self.name_to_constant[arg1_name], torch.Tensor
924            ):
925                tensor_constant = self.name_to_constant[arg1_name]
926                if tensor_constant.numel() == 1:
927                    updated_args = list(args)
928                    updated_args[1] = self.name_to_constant[arg1_name].item()
929
930                    fx_node = self.fx_graph.call_function(
931                        torch.ops.aten.div.Scalar_mode,
932                        tuple(updated_args),
933                        kwargs,
934                    )
935
936                    # TODO: covnert sourceRange() into stack_trace
937                    # fx_node.meta["stack_trace"] = node.sourceRange()
938
939                    output_name = node.output().debugName()
940                    self.name_to_node[output_name] = fx_node
941                    return
942
943        self.convert_call_function_op(node)
944
945    def convert_aten___getitem__(self, node: torch._C.Node):
946        input_container, index = tuple(
947            self.get_fx_value_by_ir_value(input) for input in node.inputs()
948        )
949        fx_node = self.fx_graph.call_function(
950            operator.getitem, (input_container, index)
951        )
952        output_name = node.output().debugName()
953        self.name_to_node[output_name] = fx_node
954
955    def convert_aten_to(self, node: torch._C.Node):
956        target = get_op_overload(node)
957        args, kwargs = self.get_args_kwargs(node, target._schema)
958
959        # special handle aten.to.dtype and aten.to.prim_dtype followed by inplace_mutation_op
960        # coz aten.to + inplace_mutation_op pattern would trigger
961        # "cannot mutate tensors with frozen storage" functionalization error.
962        # To work around the issue, we override the copy to be True, so that the output
963        # is for sure not an alias of input
964        if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype:
965            user_nodes = [use.user for use in node.output().uses()]
966            user_targets = [
967                get_op_overload(user_node)
968                for user_node in user_nodes
969                if user_node.schema() != "(no schema)"
970            ]
971            has_mutable_target = any(
972                target._schema.is_mutable for target in user_targets
973            )
974
975            if has_mutable_target:
976                assert len(args) >= 4
977                new_args = list(args)
978                new_args[3] = True  # copy, override to True
979                fx_node = self.fx_graph.call_function(
980                    torch.ops.aten.to.dtype, tuple(new_args)
981                )
982                # temp hack to work around the issue https://github.com/pytorch/pytorch/issues/131679
983                # When this issue is fixed, the clone node would be no longer needed
984                clone_node = self.fx_graph.call_function(
985                    torch.ops.aten.clone.default, (fx_node,)
986                )
987                output_name = node.output().debugName()
988                self.name_to_node[output_name] = clone_node
989                return
990
991        self.convert_call_function_op(node)
992
993    def convert_aten_add(self, node: torch._C.Node):
994        if node.schema() == "(no schema)":
995            if isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance(
996                node.inputsAt(1).type(), torch.ListType
997            ):
998                target = torch.ops.aten.add.t
999            else:
1000                raise RuntimeError(f"unable to determind the target for {node}")
1001        else:
1002            target = get_op_overload(node)
1003
1004        if target == torch.ops.aten.add.t:
1005            # special handle python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]" for
1006            # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'.
1007            args, kwargs = self.get_args_kwargs(node, target._schema)
1008            output_name = node.output().debugName()
1009            self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args)
1010        else:
1011            self.convert_call_function_op(node)
1012
1013    def _check_prim_loop_support(self, node):
1014        inputs = list(node.inputs())
1015
1016        # TODO: (1/N) stage.
1017        if inputs[0].debugName() not in self.name_to_constant:
1018            raise RuntimeError(
1019                "prim::Loop currently cannot run with dynamic value of number of iterations."
1020            )
1021
1022        # Make sure the condition is not updated in the subblock.
1023        subblock = next(node.blocks())
1024        condition_output_name = next(subblock.outputs()).debugName()
1025        for node in subblock.nodes():
1026            if (
1027                node.outputsSize() == 1
1028                and node.output().debugName() == condition_output_name
1029            ):
1030                raise RuntimeError(
1031                    "prim::Loop currently cannot run with dynamic value of condition."
1032                )
1033            if node.outputsSize() >= 2:
1034                for outp in node.outputs():
1035                    if outp.debugName() == condition_output_name:
1036                        raise RuntimeError(
1037                            "prim::Loop currently cannot run with dynamic value of condition."
1038                        )
1039
1040    def convert_prim_Loop(self, node: torch._C.Node):
1041        inputs = list(node.inputs())
1042        self._check_prim_loop_support(node)
1043
1044        num_iterations = self.get_fx_value_by_ir_value(inputs[0])
1045
1046        # Find inputs.
1047        loop_local_arguments = [inp.debugName() for inp in inputs[2:]]
1048
1049        global_arguments = self._identify_inputs_as_arguments(node)
1050
1051        # Lift parameters as inputs.
1052        for block in node.blocks():
1053            global_arguments = global_arguments.union(
1054                self.blocks_to_lifted_attrs[block]
1055            )
1056
1057        global_arguments = list(global_arguments)
1058
1059        subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
1060            node, global_arguments
1061        )
1062
1063        assert len(subgraph_nodes) == 1
1064        subgraph_converter = subgraph_converters[0]
1065        if not self.is_top_level_graph():
1066            self.name_update_from_subblock_to_parent = (
1067                self.name_update_from_subblock_to_parent.union(
1068                    subgraph_converter.name_update_from_subblock_to_parent
1069                )
1070            )
1071
1072        fx_block_args = [
1073            self.get_fx_value_by_fqn(name)
1074            for name in loop_local_arguments + global_arguments
1075        ]
1076        for iter_idx in range(num_iterations):
1077            loop_node = self.fx_graph.call_function(
1078                execute_subgraph_from_prim_loop,
1079                # Check execute_node function for the expected arguments order.
1080                (
1081                    subgraph_nodes[0],
1082                    iter_idx,
1083                    len(loop_local_arguments),
1084                    *fx_block_args,
1085                ),
1086                {},
1087            )
1088
1089            # Update the value of loop local variables.
1090            if node.outputsSize() >= 1:
1091                for i, outp in enumerate(node.outputs()):
1092                    output_name = outp.debugName()
1093                    self.name_to_node[output_name] = self.fx_graph.call_function(
1094                        operator.getitem,
1095                        (
1096                            loop_node,
1097                            i + 1,
1098                        ),  # + 1 because the 0th element is the condition.
1099                    )
1100                    fx_block_args[i] = self.name_to_node[output_name]
1101
1102            # Update the value of global variables, whose values are modified inplace.
1103            for i, name in enumerate(
1104                subgraph_converter.name_update_from_subblock_to_parent
1105            ):
1106                self.name_to_node[name] = self.fx_graph.call_function(
1107                    operator.getitem,
1108                    (
1109                        loop_node,
1110                        i + node.outputsSize() + 1,
1111                    ),  # + 1 because the 0th element is the condition.
1112                )
1113                global_argument_index = global_arguments.index(name)
1114                fx_block_args[
1115                    i + node.outputsSize() + global_argument_index
1116                ] = self.name_to_node[name]
1117
1118    def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
1119        for block in if_node.blocks():
1120            for node in block.nodes():
1121                if node.kind() == "prim::SetAttr":
1122                    raise RuntimeError(
1123                        "During converting prim::If to torch.cond, found prim::SetAttr op"
1124                        " which is not supported yet. Please file an issue if you come "
1125                        "across this error."
1126                    )
1127
1128    def convert_prim_If(self, node: torch._C.Node):
1129        self._check_set_attr_in_if_block(node)
1130
1131        inputs = list(node.inputs())
1132        assert len(inputs) == 1
1133        predicate = self.get_fx_value_by_ir_value(inputs[0])
1134
1135        # Find inputs.
1136        arguments = self._identify_inputs_as_arguments(node)
1137
1138        # Lift parameters as inputs.
1139        for block in node.blocks():
1140            arguments = arguments.union(self.blocks_to_lifted_attrs[block])
1141
1142        arguments = list(arguments)
1143        subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)
1144
1145        assert len(subgraph_nodes) == 2
1146
1147        fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments]
1148
1149        args = (
1150            predicate,
1151            subgraph_nodes[0],
1152            subgraph_nodes[1],
1153            tuple(fx_block_args),
1154        )
1155
1156        cond_node = self.fx_graph.call_function(torch.cond, args, {})
1157
1158        # prim::If can also have zero output.
1159        if node.outputsSize() == 1:
1160            output_name = node.output().debugName()
1161            self.name_to_node[output_name] = cond_node
1162        elif node.outputsSize() > 1:
1163            for i, output in enumerate(node.outputs()):
1164                output_name = output.debugName()
1165                getitem = self.fx_graph.call_function(operator.getitem, (cond_node, i))
1166                self.name_to_node[output_name] = getitem
1167
1168    def convert_aten_Bool(self, node: torch._C.Node):
1169        self._convert_as_noop(node)
1170
1171    def convert_prim_Enter(self, node: torch._C.Node):
1172        # export generally treats prim::Enter as noop
1173        # The only context manager export supports is aten::enable_grad.
1174        # Unfortunately, TorchScript does not support aten::enable_grad yet.
1175        # TODO: support aten::enable_grad in both TorchScript and Converter.
1176        return
1177
1178    def convert_prim_Exit(self, node: torch._C.Node):
1179        # export treats prim::Exit as noop
1180        return
1181
1182    def _convert_as_noop(self, node: torch._C.Node):
1183        # Converts the node as a no-op by mapping its output node as arg[0]
1184
1185        target = get_op_overload(node)
1186        schema = target._schema
1187
1188        args, kwargs = self.get_args_kwargs(node, schema)
1189
1190        output_name = node.output().debugName()
1191        self.name_to_node[output_name] = args[0]
1192
1193    def convert_profiler__record_function_exit(self, node: torch._C.Node):
1194        # _record_function_exit has side effect so we keep it in fx.graph
1195        # currently, _record_function_enter_new and _record_function_exit are
1196        # discarded during `retrace_as_exported_program`.
1197        target = torch.ops.profiler._record_function_exit
1198        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
1199        self.fx_graph.call_function(target, args)
1200
1201    def convert_prim_tolist(self, node: torch._C.Node):
1202        # prim::tolist cannot be supported by `_convert_standard_operators`
1203        # since it requires call_method instead of call_function.
1204        target = "tolist"
1205        args = (self.get_fx_value_by_ir_value(next(node.inputs())),)
1206        fx_node = self.fx_graph.call_method(target, args)
1207        output_name = node.output().debugName()
1208        self.name_to_node[output_name] = fx_node
1209
1210    def convert_prim_Uninitialized(self, node: torch._C.Node):
1211        # `prim::Uninitialized` is inserted by the compiler when it can prove
1212        # the value will never be used. It can be introduced by exceptions,
1213        # breaks, continues, and returns.
1214        # So we add a dummy constant to the graph.
1215        output_name = node.output().debugName()
1216        self.name_to_constant[output_name] = torch.Tensor()
1217
1218    def _convert_standard_operators(self, node: torch._C.Node):
1219        target = kind_to_standard_operators[node.kind()]
1220        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
1221        fx_node = self.fx_graph.call_function(target, args)
1222        output_name = node.output().debugName()
1223        self.name_to_node[output_name] = fx_node
1224
1225    def convert_node(self, node: torch._C.Node):
1226        node_kind = node.kind()
1227
1228        # Get handler based on namespace and operator name.
1229        # Provide a default node handler as well in case we don't find
1230        # matching converter for that.
1231        handler_func_name = ir_name_to_func_name(node_kind)
1232        handler_func = getattr(self, handler_func_name, self.convert_call_function_op)
1233
1234        # str calls print function implemented in CPP. To avoid repeating
1235        # the entire logic here, we simply keep first line from node string (getting rid
1236        # of sub-blocks IR prints).
1237        node_str = "".join(str(node).split("\n")[:1])
1238        log.debug("[%s] converts [%s]", handler_func.__name__, node_str)
1239        try:
1240            handler_func(node)
1241        except Exception as e:
1242            raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e
1243
1244    def convert_graph_outputs(self):
1245        args = []
1246        outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
1247            self.name_update_from_subblock_to_parent
1248        )
1249        for output_name in outp_name_list:
1250            if output_name in self.name_to_node:
1251                fx_node = self.name_to_node[output_name]
1252                # TODO: Revisit this later after HigherOrderOp design changes.
1253                # Currently, we cannot directly return input as output.
1254                if (
1255                    not self.is_top_level_graph()
1256                    and isinstance(fx_node, torch.fx.Node)
1257                    and fx_node.op == "placeholder"
1258                ):
1259                    fx_node = self.fx_graph.call_function(torch.clone, (fx_node,))
1260                args.append(fx_node)
1261                self.output_specs.append(
1262                    OutputSpec(
1263                        OutputKind.USER_OUTPUT,
1264                        arg=TensorArgument(name=output_name),
1265                        target=output_name,
1266                    )
1267                )
1268            elif output_name in self.name_to_constant:
1269                args.append(self.name_to_constant[output_name])
1270                self.output_specs.append(
1271                    OutputSpec(
1272                        OutputKind.USER_OUTPUT,
1273                        arg=ConstantArgument(
1274                            name=output_name, value=self.name_to_constant[output_name]
1275                        ),
1276                        target=output_name,
1277                    )
1278                )
1279            else:
1280                raise ValueError(f"Output {output_name} not found")
1281
1282        if len(args) == 0:
1283            # Sub-block of prim::If can have zero output.
1284            self.fx_graph.output([])
1285        elif len(args) == 1:
1286            self.fx_graph.output(
1287                args[0]
1288            )  # Get rid of an extra list wrapped around final output.
1289        elif len(args) > 1:
1290            self.fx_graph.output(
1291                args
1292            )  # For prim::Loop and prim::If with multiple outputs.
1293        else:
1294            # Sub-block of prim::Loop can have multiple outputs.
1295            self.fx_graph.output(args)
1296
1297
1298class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
1299    """
1300    Run TS2FXGraphConverter in an explain mode. It collects all failed operators conversions
1301    and provide that information to users. In order to collect all failed conversions, it
1302    also mocks some internal attributes (e.g., name_to_node).
1303    """
1304
1305    class _DictMock(dict):
1306        def __init__(self, dict_data, mock_value):
1307            super().__init__(dict_data)
1308            self.mock_value = mock_value
1309
1310        def __getitem__(self, key):
1311            # If the original dictionary has the key, return its value.
1312            # Otherwise, return the mock value.
1313            if not super().__contains__(key):
1314                return self.mock_value
1315            return super().__getitem__(key)
1316
1317        def __contains__(self, key):
1318            return True
1319
1320    def __init__(
1321        self,
1322        ts_graph: Union[torch._C.Graph, torch._C.Block],
1323        name_to_param: Dict[str, torch.Tensor],
1324        name_to_buffer: Dict[str, torch.Tensor],
1325        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
1326        name_to_non_tensor_attribute: Dict[str, Any],
1327        name_to_constant: Dict[str, Any],
1328    ):
1329        super().__init__(
1330            ts_graph,
1331            name_to_param,
1332            name_to_buffer,
1333            blocks_to_lifted_attrs,
1334            name_to_non_tensor_attribute,
1335            name_to_constant,
1336        )
1337
1338        # Data to keep track of unsupported nodes.
1339        self.unsupported_node_list: List[torch._C.Node] = []
1340
1341        # Add mock to needed attributes.
1342        self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
1343            self.name_to_node,
1344            # Dummy node.
1345            torch.fx.Node(
1346                None,  # type: ignore[arg-type]
1347                "mock",
1348                "call_function",
1349                lambda: None,
1350                (),
1351                {},
1352            ),
1353        )
1354
1355    def explain(self):
1356        self.convert_graph_inputs()
1357        for node in self.ts_graph.nodes():
1358            self.convert_node(node)
1359        self.convert_graph_outputs()
1360
1361    def convert_node(self, node):
1362        try:
1363            super().convert_node(node)
1364        except Exception as e:
1365            self.unsupported_node_list.append(node)
1366
1367
1368@contextmanager
1369def disable_logging(log):
1370    disabled = log.disabled
1371    log.disabled = True
1372    try:
1373        yield
1374    finally:
1375        log.disabled = disabled
1376
1377
1378class TS2EPConverter:
1379    # TorchScript model to ExportedProgram converter
1380    def __init__(
1381        self,
1382        ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
1383        sample_args: Tuple[Any, ...],
1384        sample_kwargs: Optional[Dict[str, Any]] = None,
1385    ):
1386        self.ts_model = ts_model
1387        self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)
1388
1389        self.sample_args = sample_args
1390        self.sample_kwargs = sample_kwargs
1391
1392        self.name_to_param: Dict[str, torch.Tensor] = {}
1393        self.name_to_buffer: Dict[str, torch.Tensor] = {}
1394        param_list = (
1395            list(self.ts_model.parameters())
1396            if not isinstance(self.ts_model, torch._C.ScriptFunction)
1397            else []
1398        )
1399        if not isinstance(self.ts_model, torch._C.ScriptFunction):
1400            for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
1401                # Check if tensor belongs to any parameter.
1402                if any(
1403                    (tensor == param).all()
1404                    for param in param_list
1405                    if tensor.shape == param.shape
1406                ):
1407                    self.name_to_param[k] = tensor
1408                else:
1409                    self.name_to_buffer[k] = tensor
1410
1411        self.name_to_non_tensor_attributes: Dict[str, Any] = {}
1412        self.name_to_constant: Dict[str, Any] = {}
1413
1414        self.lift_get_attr()
1415
1416    def convert(self) -> ExportedProgram:
1417        log.info(
1418            """
1419TS2EPConverter logging starts from here.
1420
1421INFO: (TORCH_LOGS="export" <cmd>)
1422    * Log TorchScript IR.
1423
1424DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
1425    * Log conversion IR by IR in a format of [<conversion handler name>] converts [<IR>].
1426        """
1427        )
1428        log.info("TorchScript graph\n\n%s\n", self.ts_graph)
1429
1430        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)
1431
1432        graph_converter = TS2FXGraphConverter(
1433            self.ts_graph,
1434            self.name_to_param,
1435            self.name_to_buffer,
1436            blocks_to_lifted_attrs,
1437            self.name_to_non_tensor_attributes,
1438            self.name_to_constant,
1439        )
1440        gm = graph_converter.convert()
1441
1442        # Post-proccessing step to deal with quantized operators.
1443        replace_quantized_ops_with_standard_ops(gm)
1444        log.info("GraphModule: %s", gm.print_readable(print_output=False))
1445
1446        ep = self.retrace_as_exported_program(
1447            gm,
1448            graph_converter.name_to_constant,
1449        )
1450        log.info("%s", ep)
1451
1452        # Post-processing step to ensure ExportedProgram has the same state_dict as
1453        # the original TorchScript model. Throw warnings for additionally populated
1454        # state_dict entries.
1455        if not isinstance(self.ts_model, torch._C.ScriptFunction):
1456            for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
1457                if k not in ep.state_dict:
1458                    warnings.warn(
1459                        f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram."
1460                    )
1461                    ep.state_dict[k] = tensor
1462
1463        return ep
1464
1465    @disable_logging(log)
1466    def explain(self, print_output=True):
1467        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)
1468
1469        graph_converter = ExplainTS2FXGraphConverter(
1470            self.ts_graph,
1471            self.name_to_param,
1472            self.name_to_buffer,
1473            blocks_to_lifted_attrs,
1474            self.name_to_non_tensor_attributes,
1475            self.name_to_constant,
1476        )
1477        graph_converter.explain()
1478        if len(graph_converter.unsupported_node_list) > 0:
1479            explain_str = "Unsupported nodes are found in the following list:"
1480            for i, n in enumerate(graph_converter.unsupported_node_list):
1481                node_str = "".join(str(n).split("\n")[:1])
1482                explain_str += f"\n\n    {i}. {n.kind()} [{node_str}]"
1483        else:
1484            explain_str = "Success!"
1485        if print_output:
1486            print(explain_str)
1487        return explain_str
1488
1489    def retrace_as_exported_program(
1490        self,
1491        gm: torch.fx.GraphModule,
1492        name_to_constant: Dict[str, Any],
1493    ):
1494        # TODO: adjust input orders to match GraphSignature convention
1495        ep = torch.export._trace._export(
1496            gm,
1497            self.sample_args,
1498            strict=False,
1499            pre_dispatch=True,
1500        )
1501
1502        # Post-processing to make sure the ExportedProgram states are correct.
1503        # Because during conversion, we set tensor constants as GetAttr,
1504        # retracing cannot recognize them as tensor constants but instead
1505        # treat them as buffers. We need to set them again here.
1506        ep._constants.update(
1507            {
1508                k: v
1509                for k, v in name_to_constant.items()
1510                if isinstance(v, (torch.Tensor, torch.ScriptObject))
1511            }
1512        )
1513        for k in name_to_constant:
1514            ep.state_dict.pop(k, None)
1515
1516        for i, spec in enumerate(ep.graph_signature.input_specs):
1517            # Mark as constant tensors for erroneously traced buffers.
1518            if spec.kind == InputKind.BUFFER and spec.target in name_to_constant:
1519                assert isinstance(
1520                    name_to_constant[spec.target], torch.Tensor
1521                ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
1522                spec.kind = InputKind.CONSTANT_TENSOR
1523        ep.verifier().check(ep)
1524
1525        return ep
1526
1527    def lift_get_attr(self):
1528        # This function lifts multiple data types.
1529
1530        #     1. Tensor constants attributes (e.g., self.data = torch.tensor([2,3]))
1531        #     to buffers. Currently, when there are tensor constants, export
1532        #     would error and ask users to register tensor constants as buffers.
1533        #     Since it is hard to manually do so for TorchScript models
1534        #     (e.g., source code is missing), this function automatically
1535        #     lifts tensor constants to be buffers.
1536
1537        #     2. ScriptObbject to constant. It will then be converted to getattr in
1538        #     in the fx graph.
1539        #
1540        # This function should happen in TS2EPConverter instead of
1541        # TS2FXGraphConverter since it gets attributes from self.ts_model
1542        # which is not accessable in TS2FXGraphConverter. It is similar to where
1543        # we collect self.name_to_param and self.name_to_buffer.
1544        name_to_attribute_fqn: Dict[str, str] = {}
1545
1546        def get_attr(fqn: str):
1547            name = fqn.split(".")
1548            v = self.ts_model
1549            for n in name:
1550                v = getattr(v, n)
1551            return v
1552
1553        def get_fqn(node: torch._C.Node):
1554            attr_name = node.s("name")
1555            input_name = node.input().debugName()
1556            root_attr_name = name_to_attribute_fqn[input_name]
1557            attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
1558            return attr_fqn
1559
1560        def _dfs_get_attr(block):
1561            for node in block.nodes():
1562                if node.kind() == "prim::CreateObject":
1563                    output_name = node.output().debugName()
1564                    name_to_attribute_fqn[output_name] = ""
1565
1566                if node.kind() == "prim::GetAttr":
1567                    attr_fqn = get_fqn(node)
1568                    value = get_attr(attr_fqn)
1569                    output_name = node.output().debugName()
1570                    name_to_attribute_fqn[output_name] = attr_fqn
1571                    if isinstance(value, torch.Tensor):
1572                        if attr_fqn not in self.name_to_buffer:
1573                            # Lift tensor constants to be a buffer
1574                            self.name_to_buffer[attr_fqn] = value
1575                    elif isinstance(value, torch.ScriptObject):
1576                        if attr_fqn not in self.name_to_constant:
1577                            self.name_to_constant[attr_fqn] = value
1578                    else:
1579                        self.name_to_non_tensor_attributes[attr_fqn] = value
1580
1581                for subblock in node.blocks():
1582                    _dfs_get_attr(subblock)
1583
1584        _dfs_get_attr(self.ts_graph)
1585