# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Takes an ExportedArtifact, or a collection of ExportedArtifacts, in execution dialect, and turns
them into a single ExecuTorch Program.

The provided ExportedArtifact's graph modules are in execution dialect, and the emitter parses and
converts them into executorch instructions. The emitter walks the provided graphs and, as it
encounters concrete values such as tensors or ints, converts them to the serialized format and
stores them in a list for later use. The emitter walks the graph by traversing fx.Nodes, which can
come in a variety of forms and are the primitives of execution at the graph module level. The three
most common ones we care about are 'call_function', 'placeholder', and 'output'. 'placeholder' and
'output' handle I/O for the module, and 'call_function' handles everything else. Within
'call_function' we may encounter an operator or delegate call, in which case we parse the schema and
emit all the inputs and outputs (unless they have already been emitted), and then convert the actual
function call into an executorch instruction such as KernelCall or DelegateCall.

When control flow is present in the graph module it takes the form of a few different types of
'call_function'. Today (June 14th 2023) only cond and map are supported. The actual operations of
these, such as the true/false branches of cond or the mapping function of map, are stored as
sub-graphmodules. When these are encountered during emission, the emitter recursively emits them and
their instructions.
"""
# TODO(jakeszwe): add information here about how weights and other parameters are handled in the
# presence of aot autograd param lifting.

# pyre-strict
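
# A rough sketch of the flow described in the docstring above (illustrative
# comment only; the node and all indices below are invented, not real output):
#
#     # Graph node:  %add = call_function[target=aten.add.out](%x, %y, out=%o)
#     # Emission:    emit EValues for %x, %y, and %o if not already emitted,
#     #              then append to the current chain something like
#     #                  Instruction(KernelCall(op_index=<aten::add.out>,
#     #                                         args=[<value ids>]))
#     # 'placeholder' and 'output' nodes only bind the method's I/O and
#     # produce no KernelCall of their own.
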
import ctypes
import hashlib
import operator
import typing
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union

import executorch.exir.memory as memory
import executorch.extension.pytree as ex_pytree
import torch
import torch.fx
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.operator.convert import is_out_variant
from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    BackendDelegateInlineData,
    Bool,
    BoolList,
    Buffer,
    Chain,
    ContainerMetadata,
    DataLocation,
    DelegateCall,
    Double,
    DoubleList,
    EValue,
    ExecutionPlan,
    FreeCall,
    Instruction,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    MoveCall,
    Null,
    Operator,
    OptionalTensorList,
    ScalarType,
    String,
    Tensor,
    TensorList,
    TensorShapeDynamism,
)
from executorch.exir.tensor import (
    AddressSpaceOverflowException,
    layout_enum,
    make_allocation_info,
    make_tensor_value,
    memory_format_enum,
    scalar_type_enum,
    TensorSpec,
)
from executorch.exir.types import LeafValueSpec, ValueSpec
from torch._subclasses.fake_tensor import FakeTensor

from torch.export.exported_program import ExportedProgram
from torch.utils import _pytree as pytree

from typing_extensions import TypeAlias


@dataclass
class _ProgramState:
    """State shared between all methods of a program and the graph module it represents.

    Initialized once within emit_program and then shared across each entry point as they are
    emitted.
    """

    # Parallel list of specs and the buffers that back them. Add 1 to any index into this list,
    # since index 0 of the constant_buffer is reserved.
    allocated_specs: List[TensorSpec] = field(default_factory=list)
    # Weights in any arbitrary graph_module only need to be compared against weights from previously
    # emitted graph modules, not any weights emitted from itself. This should speed up the lookup
    # from O(N) to O(1).
    cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
    cached_spec_mutable_hash_values: Dict[str, int] = field(default_factory=dict)
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    mutable_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
    # and should be copied to Program.backend_delegate_data.
    backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)

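# A minimal sketch of the index convention above (illustrative comment, not
# executed): index 0 of constant_buffer is the reserved empty placeholder, so
# the first real constant lands at index 1.
#
#     state = _ProgramState()
#     state.constant_buffer.append(Buffer(storage=b"\x00\x01"))
#     idx = len(state.constant_buffer) - 1  # == 1; 0 is the reserved Buffer(b"")
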

@dataclass
class _EmitterState:
    """State of a single emitter.

    Local to at least the entry point, and may be local to a subgraph of an entry point originating
    from control flow.
    """

    values: List[EValue]
    operators: List[Operator]
    delegates: List[BackendDelegate]
    operator_cache: Dict[Tuple[str, str], int]
    delegate_cache: Dict[bytes, int]
    emit_stacktrace: bool

    spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)

    def spec2id(self, spec: TensorSpec) -> int:
        """Maps a TensorSpec to its value index in the values array."""
        assert spec in self.spec2id_dict, f"Spec is not found: {spec.debug()}"
        return self.spec2id_dict[spec]


@dataclass
class _AbstractValue:
    """Represents an already emitted EValue"""

    # Index in the values table of this EValue.
    id: int

    # Used for sanity checks for functions that expect to only receive AbstractValues.
    tensor: Optional[Tensor]


_EmitterValue: TypeAlias = Union[
    _AbstractValue, List[_AbstractValue], Tuple[_AbstractValue, ...]
]

_PythonValue: TypeAlias = Union[bool, int, float, torch.Tensor, List["_PythonValue"]]
_SchemaType: TypeAlias = Union[
    torch.OptionalType,
    torch.ListType,
    torch.FloatType,
    torch.BoolType,
    torch.IntType,
    torch.StringType,
    torch.TensorType,
]

_Target: TypeAlias = Union[Callable[..., _PythonValue], str]

_Argument: TypeAlias = Union[
    _EmitterValue,
    Tuple["_Argument", ...],
    List["_Argument"],
    Dict[str, "_Argument"],
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.Tensor,
    torch.memory_format,
    torch.layout,
    None,
]

_DelegateDebugIdentifierMap: TypeAlias = Union[
    Dict[int, Tuple[int]], Dict[str, Tuple[int]]
]


# pyre-ignore[13]: Attribute `node` is never initialized.
class _Emitter(torch.fx.Interpreter):
    """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
    given traced torch.fx.GraphModule to the flatbuffer schema."""

    node: torch.fx.Node

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        emitter_state: _EmitterState,
        program_state: _ProgramState,
        instruction_start_offset: int = 0,
        binding_input_values: Optional[List[_AbstractValue]] = None,
        binding_output_values: Optional[List[_AbstractValue]] = None,
    ) -> None:
        super().__init__(graph_module)
        self.emitter_state = emitter_state
        self.program_state = program_state
        self.outputs: List[int] = []

        self.chain = Chain(
            inputs=[],
            outputs=[],
            instructions=[],
            stacktrace=None,
        )

        if "non_const_buffer_sizes" not in graph_module.meta.keys():
            raise RuntimeError(
                "Must set 'non_const_buffer_sizes' in graph meta in memory planning pass"
            )
        self.instruction_start_offset = instruction_start_offset
        self.binding_input_values = binding_input_values
        self.binding_output_values = binding_output_values
        self.graph_module: torch.fx.GraphModule = graph_module
        self.nodes: List[torch.fx.Node] = list(self.graph_module.graph.nodes)

        # Marks each placeholder node with its order so that we can match it with the corresponding
        # AbstractValue coming from the top level.
        self.placeholder_count = 0

        self.concrete_output_ids: List[_AbstractValue] = []
        self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
        self.instr_id_to_delegate_debug_id_map: Dict[
            int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
        ] = {}

    def _emit_node_specific_error(self, node: torch.fx.Node, err_msg: str) -> str:
        """Returns 'err_msg' with node-specific information attached."""
        err_msg = f"Failed with error: {str(err_msg)}\n" + inspect_node(
            self.graph_module.graph, node
        )
        return err_msg

    def _internal_assert_emitter(
        self, pred: bool, node: torch.fx.Node, assert_msg: str
    ) -> None:
        """If pred is False, construct and raise a node-specific error message."""
        if not pred:
            raise InternalError(self._emit_node_specific_error(node, assert_msg))

    def _emit_int_list(self, val: List[_Argument]) -> EValue:
        """Emits a list of integers as a collection of EValues.

        For every argument in 'val':
            - If it is a concrete value, then emit it and then place its location in the boxed list
            - If it is already an abstract value, then just place its location in the boxed list

        Int lists are boxed to handle symints, whose values are determined at runtime but can
        still end up inside lists for ops like view_copy(Tensor self, SymInt[] size).
        """
        boxed_list = []
        for item in val:
            if isinstance(item, _AbstractValue):
                boxed_list.append(item.id)
            elif isinstance(item, int):
                boxed_list.append(
                    self._emit_evalue(self._constant_to_evalue(item, None)).id
                )
            else:
                self._internal_assert_emitter(
                    False, self.node, "Unsupported type encountered in int list."
                )

        return EValue(IntList(boxed_list))
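
    # Boxing sketch for _emit_int_list (illustrative; the indices are invented):
    # for view_copy(Tensor self, SymInt[] size) with size == [2, s], where s is
    # a SymInt already emitted at value index 7, the result is roughly
    #     EValue(IntList([12, 7]))  # 12 = index of a freshly emitted Int(2)
    # i.e. each element of the boxed list is an index into the values table,
    # never a raw integer.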

    def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
        """Emits a list type.

        Emits the list stored in val. If the list is of Tensors, Optionals, or Ints, the emitted
        list is boxed; otherwise the values are constant at runtime and stored inline.

        NOTE: When symbool and symfloat are supported, bool and float lists will be stored boxed.
        """

        if isinstance(val_type, torch.BoolType):
            return EValue(BoolList(typing.cast(List[bool], val)))

        if isinstance(val_type, torch.IntType):
            return self._emit_int_list(val)

        if isinstance(val_type, torch.FloatType):
            return EValue(DoubleList(typing.cast(List[float], val)))

        if isinstance(val_type, torch.TensorType):
            values = []
            for v in val:
                assert isinstance(v, _AbstractValue)
                self._internal_assert_emitter(
                    v.tensor is not None,
                    self.node,
                    "AbstractValue corresponding to tensor type doesn't contain tensor value",
                )
                values.append(v.id)
            return EValue(TensorList(values))

        if isinstance(val_type, torch.OptionalType):
            # refine further
            actual_type = val_type.getElementType()
            if isinstance(actual_type, torch.TensorType):
                vals = []
                for v in val:
                    if v is None:
                        vals.append(-1)
                    else:
                        assert isinstance(v, _AbstractValue)
                        vals.append(v.id)
                return EValue(OptionalTensorList(vals))

        raise ExportError(
            ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}"
        )

    def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
        """Constructs an EValue from the given TensorSpec."""

        allocation_info = None
        buffer_idx = 0

        # Need to memory plan
        # Some users set mem_id on all tensors and then rely on the
        # default algos to set offsets, so need to check both.
        if spec.mem_id is not None and spec.mem_offset is not None:
            # Tensor is an activation.
            self._internal_assert_emitter(
                isinstance(spec.mem_id, int) and spec.mem_id >= 0,
                self.node,
                f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}",
            )

            self._internal_assert_emitter(
                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
                self.node,
                f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}",
            )
            try:
                allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset)
            except AddressSpaceOverflowException as e:
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        (
                            f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, "
                            f"such as ConstraintBasedSymShapeEvalPass, this may be the cause of an "
                            f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) "
                            "during torch.export()."
                        ),
                    )
                )

        if spec.const:
            # Tensor with a blob we need to serialize. May not actually be constant at runtime
            # if it's a weight with an associated gradient
            spec_array_type = (
                ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
            )

            buffer_data = (
                bytes(
                    ctypes.cast(
                        typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
                        ctypes.POINTER(spec_array_type),
                    ).contents
                )
                if spec.allocated_memory != 0
                else b""
            )

            hashed = hashlib.sha256(buffer_data).hexdigest()

            if allocation_info:
                buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                    hashed, -1
                )
            else:
                buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)

            # Haven't seen this constant before
            if buffer_idx == -1:
                # Update buffer_idx to point to the end of the list where we are adding the new buffer.
                buffer = Buffer(storage=buffer_data)
                self.program_state.allocated_specs.append(spec)
                # +1 because the first buffer location is reserved

                if allocation_info:
                    buffer_idx = len(self.program_state.mutable_buffer)
                    self.program_state.cached_spec_mutable_hash_values[hashed] = (
                        buffer_idx
                    )
                    self.program_state.mutable_buffer.append(buffer)
                else:
                    buffer_idx = len(self.program_state.constant_buffer)
                    self.program_state.cached_spec_hash_values[hashed] = buffer_idx
                    self.program_state.constant_buffer.append(buffer)

            if spec.const and spec.nbytes() != len(buffer_data):
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        f"Tensor spec has buffer of size {len(buffer_data)}, but expected nbytes of {spec.nbytes()}",
                    )
                )

        # For constant tensors, allocation_info = None.
        return EValue(make_tensor_value(buffer_idx, allocation_info, spec))
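
    # Dedup sketch for the hashing above (illustrative): constant specs that
    # serialize to identical bytes share one Buffer entry.
    #     hashed = hashlib.sha256(buffer_data).hexdigest()
    #     # first sighting : Buffer appended; cached_spec_hash_values[hashed] = new index
    #     # later sightings: cache hit; the same buffer_idx is reused, nothing appended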

    def _get_list_tuple_jit_type(
        self, val: Union[Tuple[_Argument], List[_Argument]]
    ) -> _SchemaType:
        """Returns the JIT type for the given Python type."""
        assert isinstance(
            val, (list, tuple)
        ), f"Input to _get_list_tuple_jit_type was expected to be an instance of a list or tuple but received {type(val)}"
        is_tensor_type = all(
            isinstance(v, _AbstractValue) and v.tensor is not None for v in val
        )
        if is_tensor_type:
            return torch.TensorType.get()
        # Check bool before int: bool is a subclass of int in Python, so the int
        # check would otherwise also match True/False.
        elif isinstance(val[0], bool):
            return torch.BoolType.get()
        elif isinstance(val[0], int):
            return torch.IntType.get()
        elif isinstance(val[0], float):
            return torch.FloatType.get()

        raise InternalError(
            self._emit_node_specific_error(
                self.node,
                "Couldn't determine JitType for list/tuple of elements. Only supports int, float, bool, and Tensor.",
            )
        )

    def _constant_to_evalue(  # noqa: C901
        self,
        val: _Argument,
        val_type: Optional[_SchemaType],
    ) -> EValue:
        """Converts a constant value to an EValue.

        Returns an EValue given the Python representation and JIT type. On common paths there should
        always be a JIT type provided. Users can pass in a None to infer the JIT type, but this
        should never be the default case due to the existence of container types.
        """
        if val is None:
            return EValue(Null())

        if isinstance(val, (list, tuple)):
            # Refine Optional[List[T]] -> List[T]. This works because if the val was None it would
            # have been converted to Null before this function call.
            if val_type is None:
                val_type = torch.ListType(
                    self._get_list_tuple_jit_type(val)  # pyre-ignore
                )
            if isinstance(val_type, torch.OptionalType):
                val_type = val_type.getElementType()
            assert isinstance(val_type, torch.ListType)
            return self._emit_list(
                typing.cast(List[_Argument], val),
                typing.cast(_SchemaType, val_type.getElementType()),
            )

        if isinstance(val, float):
            return EValue(Double(val))

        if isinstance(val, bool):
            return EValue(Bool(val))

        if isinstance(val, int):
            return EValue(Int(val))

        if isinstance(val, str):
            return EValue(String(val))

        if isinstance(val, torch.dtype):
            return EValue(Int(scalar_type_enum(val)))

        if isinstance(val, torch.layout):
            return EValue(Int(layout_enum(val)))

        if isinstance(val, torch.memory_format):
            try:
                return EValue(Int(memory_format_enum(val)))
            except KeyError:
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        f"Tensor has a memory_format that is unsupported in ExecuTorch: {val}",
                    )
                )

        if isinstance(val, torch.Tensor):
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                self._emit_node_specific_error(
                    self.node,
                    "constant_to_evalue should not be encountering constant tensors; they should be emitted through other codepaths.",
                ),
            )

        raise ExportError(
            ExportErrorType.NOT_SUPPORTED,
            self._emit_node_specific_error(
                self.node, f"Unsupported constant type: {type(val).__name__}"
            ),
        )
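
    # Constant dispatch sketch for the method above (illustrative):
    #     None        -> Null()
    #     3.14        -> Double(3.14)
    #     True        -> Bool(True)
    #     7           -> Int(7)
    #     "dim"       -> String("dim")
    #     torch.float -> Int(scalar_type_enum(torch.float))
    #     [1, 2]      -> boxed via _emit_list / _emit_int_list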

    def _emit_evalue(self, val: EValue) -> _AbstractValue:
        """Writes an EValue to the emitter state.

        Given an EValue, adds it to the emitter_state's values table and returns the AbstractValue
        representing it.
        """
        self.emitter_state.values.append(val)
        tensor = val.val if isinstance(val.val, Tensor) else None
        return _AbstractValue(len(self.emitter_state.values) - 1, tensor)
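
    # E.g. (illustrative): if the values table currently holds 5 entries,
    # _emit_evalue(EValue(Int(3))) appends at index 5 and returns
    # _AbstractValue(id=5, tensor=None).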

    def _emit_spec(self, spec: ValueSpec) -> _EmitterValue:
        """Given the provided spec, constructs the corresponding EValue and emits it."""

        def _process(spec: LeafValueSpec) -> _AbstractValue:
            if isinstance(spec, (list, tuple)):
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        "Node spec should be either a non-nested container or a scalar object",
                    )
                )

            # ScalarSpec can theoretically be handled fine, but it shouldn't be appearing, so rather
            # than handle it, assert that it isn't supposed to be present. In the future if it has a
            # reason to appear we can relax this assert.
            self._internal_assert_emitter(
                isinstance(spec, TensorSpec),
                self.node,
                f"Invalid node spec: expected TensorSpec, received {spec}",
            )

            ret = self._emit_evalue(self._tensor_spec_to_evalue(spec))  # pyre-ignore
            self.emitter_state.spec2id_dict[spec] = ret.id  # pyre-ignore
            return ret

        return pytree.tree_map(_process, spec)
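
    # Sketch of the mapping above (illustrative): a nested spec such as
    #     (spec_a, [spec_b, spec_c])          # all TensorSpecs
    # is emitted leaf by leaf via pytree.tree_map, yielding
    #     (_AbstractValue(id_a, ...), [_AbstractValue(id_b, ...), _AbstractValue(id_c, ...)])
    # with the container structure preserved.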

    def _merge_chain(self, chain: Chain) -> None:
        """Merges the chain generated from subgraphs (like those originating from control flow) back
        into the main program chain."""
        self.chain.instructions.extend(chain.instructions)

    def _emit_cond(
        self,
        args: Tuple[_Argument, ...],
        subemitter_binding_output_values: Optional[List[_AbstractValue]],
    ) -> List[_AbstractValue]:
        """Emits control_flow.cond.

        Converts the higher order op into jumps and inlines the submodules of the true and false
        branches. Control flow can be nested. The general emitted structure is:

            <Jump Instruction>  - decides which branch to take
            <True Branch>
            <Jump Instruction>  - jumps to End Of Cond after executing the true branch
            <False Branch>
            <End Of Cond>
        """
        pred, true_branch, false_branch, inputs = args

        # Emit the true branch.
        assert isinstance(true_branch, torch.fx.GraphModule)
        true_branch_emitter = _Emitter(
            true_branch,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions)
            + 1,
            binding_input_values=typing.cast(List[_AbstractValue], inputs),
            binding_output_values=subemitter_binding_output_values,
        )
        true_branch_emitter.run()

        # Emit the jump.
        assert isinstance(pred, _AbstractValue)
        jf_instruction_to_skip_true = Instruction(
            JumpFalseCall(
                cond_value_index=pred.id,
                destination_instruction=self.instruction_start_offset
                + len(self.chain.instructions)
                + len(true_branch_emitter.chain.instructions)
                # This jump instruction should point at the instruction that comes after all
                # instructions for the true branch. The reason we add 2 is that we need to account
                # for this instruction we are creating right now and the jump instruction that the
                # true branch will create.
                + 2,
            )
        )

        # Insert the branch picking jump instruction to the main chain.
        self.chain.instructions.append(jf_instruction_to_skip_true)
        # Now that we created the true branch instructions, we move them to the main chain.
        self._merge_chain(true_branch_emitter.chain)

        # Emit the false branch.
        assert isinstance(false_branch, torch.fx.GraphModule)
        false_branch_emitter = _Emitter(
            false_branch,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions)
            + 1,
            binding_input_values=typing.cast(List[_AbstractValue], inputs),
            binding_output_values=subemitter_binding_output_values,
        )
        false_branch_emitter.run()

        # We bake in a constant False because this will trigger the instruction to jump over all
        # false branch instructions and point at the first instruction after the control flow.
        value = self._emit_evalue(EValue(Bool(False)))
        jf_instruction_to_skip_false = Instruction(
            JumpFalseCall(
                cond_value_index=value.id,
                destination_instruction=self.instruction_start_offset
                + len(self.chain.instructions)
                + len(false_branch_emitter.chain.instructions)
                + 1,
            )
        )
        self.chain.instructions.append(jf_instruction_to_skip_false)
        self._merge_chain(false_branch_emitter.chain)
        return subemitter_binding_output_values
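
    # Emitted layout for a cond whose first instruction lands at offset k
    # (illustrative; T and F are the true/false chain lengths):
    #     k                  JumpFalseCall(pred)        # if False, jump to k+T+2
    #     k+1 .. k+T         true-branch instructions
    #     k+T+1              JumpFalseCall(Bool(False)) # always jumps to k+T+F+2
    #     k+T+2 .. k+T+F+1   false-branch instructions
    #     k+T+F+2            first instruction after the cond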
644*523fa7a6SAndroid Build Coastguard Worker
645*523fa7a6SAndroid Build Coastguard Worker    def _emit_map(
646*523fa7a6SAndroid Build Coastguard Worker        self,
647*523fa7a6SAndroid Build Coastguard Worker        args: Tuple[_Argument, ...],
648*523fa7a6SAndroid Build Coastguard Worker        subemitter_binding_output_values: List[_AbstractValue],
649*523fa7a6SAndroid Build Coastguard Worker    ) -> List[_AbstractValue]:
650*523fa7a6SAndroid Build Coastguard Worker        """Emits torch.map.
651*523fa7a6SAndroid Build Coastguard Worker
652*523fa7a6SAndroid Build Coastguard Worker        Converts the higher order op into a loop constructed from jump instructions and primitive
653*523fa7a6SAndroid Build Coastguard Worker        int operations. A concat-like custom op is also injected into the submodule code to handle
654*523fa7a6SAndroid Build Coastguard Worker        the construction of the map output.
655*523fa7a6SAndroid Build Coastguard Worker
        Below is what the input graph provided to _emit_map looks like:

        class TestMapCond(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return control_flow.map(map_fn, x, y)

        Corresponding graph:

        def forward(self, arg0_1, arg1_1):
            submodule_0 = self.submodule_0
            map_1 = torch.ops.higher_order.map_impl(submodule_0, arg0_1, arg1_1)
            submodule_0 = arg0_1 = arg1_1 = None
            return [map_1]

        submodule_0:

        def forward(self, arg0_1, arg1_1):
            add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
            arg0_1 = arg1_1 = None
            return add_tensor

        After the transformations done by _emit_map, the submodule program looks like this:

        def forward(self, arg0_1, arg1_1):
            sym_size = torch.ops.aten.sym_size(arg0_1)
            # The emitter creates a variable here to track the iteration index.
            select_copy_tensor = torch.ops.aten.select(arg0_1, 0, iteration_index)
            add_tensor = torch.ops.aten.add.Tensor(select_copy_tensor, arg1_1)
            arg0_1 = arg1_1 = None
            output_of_map = torch.ops.executorch.prim.et_copy_index(
                output_of_map, add_tensor, iteration_index
            )
            iteration_index = torch.ops.executorch.prim.add.int(
                iteration_index, 1, iteration_index
            )
            done_bool = torch.ops.executorch.prim.eq.int(
                iteration_index, sym_size, done_bool
            )
            # The emitter inserts an instruction here: if done_bool == False, jump back to the
            # select_copy op; otherwise fall through and return.
            return add_tensor
        """
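        # A minimal sketch of the chain this method emits (not executed; instruction indices are
        # relative to instruction_start_offset and the value names are hypothetical):
        #
        #   0: KernelCall aten::sym_size.int(x, 0) -> sym_size
        #   1: KernelCall aten::select_copy.int_out(x, 0, iter_idx) -> x_slice     <- loop head
        #   2..k: instructions emitted for the map submodule body -> body_out
        #   k+1: KernelCall executorch_prim::et_copy_index.tensor(acc, body_out, iter_idx)
        #   k+2: KernelCall executorch_prim::add.Scalar(iter_idx, 1) -> iter_idx
        #   k+3: KernelCall executorch_prim::eq.Scalar(iter_idx, sym_size) -> done
        #   k+4: JumpFalseCall(done, destination_instruction=1)  # loop back while not done
        #   k+5: KernelCall executorch_prim::sub.Scalar(iter_idx, sym_size) -> iter_idx  # reset to 0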
        assert isinstance(
            subemitter_binding_output_values, (list, tuple)
        ), f"Expected a list or tuple for subemitter_binding_output_values for map. Got {subemitter_binding_output_values}."

        if len(subemitter_binding_output_values) != 1:
            raise RuntimeError(
                f"Multiple outputs are not supported. Got {len(subemitter_binding_output_values)}."
            )
        f, mapped_args, inputs = args
        assert isinstance(mapped_args, (list, tuple))
        num_mapped_args: int = len(mapped_args)
        if num_mapped_args != 1:
            raise RuntimeError(
                f"Emitting map with more than one mapped arg is not supported. Got {num_mapped_args}."
            )
        x = mapped_args[0]

        assert isinstance(f, torch.fx.GraphModule)

        # Generate the EValue that we will use as our iterator index to keep track of which
        # iteration we are currently on.
        iter_idx = self._emit_evalue(EValue(Int(0)))
        # Generate the kernel call that will output the number of iterations we need to run for this
        # input tensor.
        op_index, op = self._get_operator(
            name="aten::sym_size",
            overload="int",
        )
        sym_size = self._emit_evalue(EValue(Int(0)))
        kernel = Instruction(
            KernelCall(
                op_index=op_index,
                args=[x.id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
            )
        )
        self.chain.instructions.append(kernel)

        # This kernel call slices the input tensor along the index specified in iter_idx to
        # produce the input slice that this iteration will work on.
        op_index, op = self._get_operator(
            name="aten::select_copy",
            overload="int_out",
        )
        # This select copy has to output to the tensor which is the input placeholder to the map
        # sub-graph. That placeholder isn't allocated an EValue id until the map emitter is run, so
        # we temporarily store -1 until the map emitter is run, during which the placeholder will be
        # allocated an EValue id. After the map emitter is run we will retrieve that id and replace
        # the -1s.
        kernel = Instruction(
            KernelCall(
                op_index=op_index,
                args=[
                    x.id,
                    self._emit_evalue(EValue(Int(0))).id,
                    iter_idx.id,
                    -1,  # input_tensor_value.id,
                    -1,  # input_tensor_value.id,
                ],
            )
        )
        # Store the index of this instruction as it will be where we will jump back to after the end
        # of an iteration.
        jump_to_instruction = self.instruction_start_offset + len(
            self.chain.instructions
        )
        self.chain.instructions.append(kernel)

        # Emit the map operator submodule.
        map_emitter = _Emitter(
            f,
            self.emitter_state,
            self.program_state,
            instruction_start_offset=self.instruction_start_offset
            + len(self.chain.instructions),
            # Only the first input is a placeholder; the rest of the inputs are args to the map fn.
            binding_input_values=[-1, *inputs],
            binding_output_values=subemitter_binding_output_values,
        )
        map_emitter.run()

        # Merge all the instructions from the map submodule.
        self._merge_chain(map_emitter.chain)
        # Get rid of the return instruction emitted by the map subemitter.
        self.chain.instructions.pop()
        # At the end of each submodule emit we normally insert a move call that moves the output of
        # the submodule to a deterministic EValue. That is especially useful for if/else branches,
        # where we want the output of either branch to land in the same EValue. We don't need a
        # move here, because our custom op executorch_prim::et_copy_index, inserted later, does
        # that for us.

        # Now that the map emitter has finished running, retrieve the input placeholder EValue id
        # and update the select_copy kernel call to output to those ids.
        kernel.instr_args.args[-1] = map_emitter.binding_input_values[0].id
        kernel.instr_args.args[-2] = kernel.instr_args.args[-1]

        self._internal_assert_emitter(
            len(map_emitter.concrete_output_ids) == 1,
            self.node,
            "Map should return only one element",
        )

        # Here we call the custom op, specially added for the map operator. The output of this
        # iteration will be appended to the accumulator tensor that we are maintaining. This
        # accumulator tensor is the actual output of the map submodule.
        op_index, op = self._get_operator(
            name="executorch_prim::et_copy_index",
            overload="tensor",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[
                    subemitter_binding_output_values[0].id,
                    map_emitter.concrete_output_ids[0].id,
                    iter_idx.id,
                ],
            )
        )
        self.chain.instructions.append(kernel)

        # Increment iter_idx to mark that we have completed an iteration.
        op_index, op = self._get_operator(
            name="executorch_prim::add",
            overload="Scalar",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
            )
        )
        self.chain.instructions.append(kernel)

        jump_bool_value = self._emit_evalue(EValue(Bool(False)))

        # Generate the kernel call to check whether or not we have completed all the iterations.
        # If not, jump back to the select_copy instruction that we generated at the beginning of
        # this section.
        op_index, op = self._get_operator(
            name="executorch_prim::eq",
            overload="Scalar",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[iter_idx.id, sym_size.id, jump_bool_value.id],
            )
        )
        self.chain.instructions.append(kernel)

        jf_beginning_loop = Instruction(
            JumpFalseCall(
                cond_value_index=jump_bool_value.id,
                destination_instruction=jump_to_instruction,
            )
        )

        self.chain.instructions.append(jf_beginning_loop)

        # Reset iter_idx in case we plan to run the model again.
        op_index, op = self._get_operator(
            name="executorch_prim::sub",
            overload="Scalar",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[iter_idx.id, sym_size.id, iter_idx.id],
            )
        )
        self.chain.instructions.append(kernel)

        return subemitter_binding_output_values

    def _emit_control_flow(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Wraps common logic for emitting all control flow operations.

        See the more specific emission functions for more details on how cond or map get emitted.
        """
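        # Illustrative note (hypothetical spec): if self.node.meta["spec"] is, say, a tuple of two
        # TensorSpecs, the tree_map below emits two output EValues up front; both branches of a
        # cond (or the map accumulator) then bind their results to those same EValues.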
        subemitter_binding_output_values = pytree.tree_map(
            lambda spec: self._emit_spec(spec),
            self.node.meta["spec"],
        )

        if target is torch.ops.higher_order.cond:
            return self._emit_cond(args, subemitter_binding_output_values)
        elif target is torch.ops.higher_order.map_impl:
            return self._emit_map(args, subemitter_binding_output_values)
        else:
            raise InternalError(
                self._emit_node_specific_error(
                    self.node, f"Unsupported control flow operator: {target}"
                )
            )

    def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
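        """Emits a view as an executorch_prim::et_view kernel call.

        Takes (input_tensor, sizes) from args, emits a fresh output EValue from this node's spec,
        and appends a KernelCall that writes the reshaped view into that output.
        """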
        assert len(args) == 2

        self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
        size_arg = self._emit_argument(args[1], torch.ListType.ofInts())
        out_arg = self._emit_argument(
            self._emit_spec(self.node.meta["spec"]), torch.TensorType  # pyre-ignore[6]
        )

        op_idx, op = self._get_operator(
            name="executorch_prim::et_view",
            overload="default",
        )
        kernel = Instruction(
            KernelCall(
                op_idx,
                args=[
                    self_arg.id,
                    size_arg.id,
                    out_arg.id,
                ],
            )
        )
        self.chain.instructions.append(kernel)
        return out_arg

    def _add_debug_handle(
        self,
        emitter_id: int,
        target: _Target,
        # pyre-ignore[11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "Optional[LoweredBackendModule]" = None,  # noqa: F821
    ) -> None:
        """Updates the debug handle information for the current node.

        If the current node is a delegate we aggregate the debug handles of the subgraph and store
        them in the map. If the current node is any other type we store the original information in
        the debug handle map and replace it with the executorch instruction index corresponding to
        this node.
        """
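        # Illustrative shapes (the handles and indices here are hypothetical): a delegate call at
        # instruction index 4 that consumed nodes with debug handles 1, 2, and 3 yields
        # self.debug_handle_map[4] == [1, 2, 3]; a regular instruction at index 7 whose node
        # carried debug handle 42 yields self.debug_handle_map[7] == 42.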
        # If it's a delegate call, collect the list of debug handles that are consumed by this
        # delegate call and store it in the debug handle map.
        if target == executorch_call_delegate:
            debug_handle_list = []
            # Use the lowered_module to fetch the original graph and its debug
            # handles.
            for node in lowered_module.original_module.graph.nodes:
                if (
                    node.op == "call_function"
                    and node.meta.get("debug_handle") is not None
                ):
                    debug_handle_list.append(node.meta.get("debug_handle"))
            self.debug_handle_map[emitter_id] = debug_handle_list
            # Debug handle for this node is the emitter_id which is essentially the index of the
            # instruction in the chain.
            self.node.meta["debug_handle"] = emitter_id
            return

        if self.node.meta.get("debug_handle") is not None:
            # Store the original debug handle in the debug handle map.
            self.debug_handle_map[emitter_id] = self.node.meta.get("debug_handle")
            # Replace the debug handle in the node's metadata with the emitter id, which represents
            # the instruction index in the chain. At runtime the instruction index is what gets
            # logged during perf/debug data logging, so storing it on the node lets us map the data
            # logged by the runtime back to the node.
            self.node.meta["debug_handle"] = emitter_id

    def _add_delegate_map(
        self,
        lowered_module: "LoweredBackendModule",  # noqa
        delegate_instruction_id: int,
    ) -> None:
        """
        Store the delegate map from this lowered module into the dictionary of delegate maps. It
        will later be used for various debugging purposes such as linking back to the original
        source code, module hierarchy, etc.
        """
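        # Resulting entry shape (the instruction id and backend name here are hypothetical):
        #   self.instr_id_to_delegate_debug_id_map[3] == {
        #       "name": "SomeBackend",
        #       "delegate_map": {...},  # whatever the backend recorded in its meta
        #   }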
        delegate_map = {}
        if hasattr(lowered_module, "meta"):
            delegate_map = lowered_module.meta.get("debug_handle_map", {})

        self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
            "name": lowered_module.backend_id,
            "delegate_map": delegate_map,
        }

    def _emit_argument(
        self, arg: _Argument, arg_type: Optional[_SchemaType]
    ) -> _AbstractValue:
        """Emits an argument to an operator or delegate if it has not already been emitted;
        otherwise returns the previously emitted location."""
        if isinstance(arg, _AbstractValue):
            return arg
        return self._emit_evalue(self._constant_to_evalue(arg, arg_type))

    def _get_sym_ret(
        self,
        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
    ) -> Optional[_AbstractValue]:
        """
        Returns an emitted placeholder EValue for a symbolic value, or None if val is not symbolic.
        """
        ret = None
        if isinstance(val, torch.SymInt):
            ret = self._emit_evalue(EValue(Int(0)))
        elif isinstance(val, torch.BoolType):
            ret = self._emit_evalue(EValue(Bool(False)))
        elif isinstance(val, torch.FloatType):
            ret = self._emit_evalue(EValue(Double(0)))
        return ret

    def _get_sym_and_fake_tensor_ret(
        self,
        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
        spec: TensorSpec,
    ) -> Union[List[_AbstractValue], _AbstractValue, Tuple[_AbstractValue, ...]]:
        # Try to get the ret if it's a sym value.
        ret = self._get_sym_ret(val)
        # If the ret is None, it means that the val is not a sym value, but a regular tensor.
        if ret is None:
            ret = self._emit_spec(spec)
        assert ret is not None, "Can't have a None ret"
        return ret

    def _emit_delegate(
        self,
        lowered_module: "LoweredBackendModule",  # noqa
        args: Tuple[_Argument, ...],
        kwargs: Dict[str, _Argument],
    ) -> _EmitterValue:
        """Emits the delegate's inputs and outputs as specified by the schema, then emits the
        delegate's blob."""
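        # The DelegateCall emitted at the end of this method takes the EValue ids of all inputs
        # followed by the ids of all (flattened) outputs, e.g. (hypothetical ids):
        #   DelegateCall(delegate_index=0, args=[in0.id, in1.id, out0.id])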
        processed_bytes = lowered_module.processed_bytes

        delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
        delegate_ret = None

        if isinstance(self.node.meta["spec"], list):
            delegate_ret = []
            for index, _ in enumerate(self.node.meta["val"]):
                ret = self._get_sym_and_fake_tensor_ret(
                    self.node.meta["val"][index], self.node.meta["spec"][index]
                )
                delegate_ret.append(ret)
        elif isinstance(self.node.meta["spec"], tuple):
            if isinstance(self.node.meta["val"], FakeTensor):
                # There is a case where node.meta["spec"] is (TensorSpec,) while node.meta["val"]
                # is a single FakeTensor.
                ret = self._get_sym_and_fake_tensor_ret(
                    self.node.meta["val"], self.node.meta["spec"][0]
                )
                delegate_ret = (ret,)
            else:
                delegate_ret = []
                for index, _ in enumerate(self.node.meta["val"]):
                    ret = self._get_sym_and_fake_tensor_ret(
                        self.node.meta["val"][index], self.node.meta["spec"][index]
                    )
                    delegate_ret.append(ret)
                delegate_ret = tuple(delegate_ret)
        elif isinstance(self.node.meta["spec"], TensorSpec):
            ret = self._get_sym_and_fake_tensor_ret(
                self.node.meta["val"], self.node.meta["spec"]
            )
            delegate_ret = ret
        else:
            raise NotImplementedError(
                f"self.node.meta['spec'] {type(self.node.meta['spec'])} is not supported"
            )
        assert delegate_ret is not None, "Can't have a None delegate_ret"
        if delegate_index is None:
            # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
            # present.
            data_index: int = len(self.program_state.backend_delegate_data)
            self.program_state.backend_delegate_data.append(
                BackendDelegateInlineData(data=processed_bytes)
            )

            backend_delegate = BackendDelegate(
                id=lowered_module.backend_id,
                processed=BackendDelegateDataReference(
                    location=DataLocation.INLINE, index=data_index
                ),
                compile_specs=lowered_module.compile_specs,
            )
            delegate_index = len(self.emitter_state.delegate_cache)
            self.emitter_state.delegates.append(backend_delegate)
            self.emitter_state.delegate_cache[processed_bytes] = delegate_index

        # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
        # function's spec and with default arguments. This requires us to store the function's spec
        # in to_backend()
        delegate_args = [
            self._emit_argument(arg, None).id
            for arg in typing.cast(List[_Argument], args)
        ]

        for elem in pytree.tree_flatten(delegate_ret)[0]:
            delegate_args.append(elem.id)

        self.chain.instructions.append(
            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
        )

        return delegate_ret

    def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
        """Given a fully qualified name, looks up the operator in the ExecuTorch Program, or adds
        it if it is not already present."""
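        # Illustrative behavior (hypothetical state): the first call with ("aten::add", "out")
        # appends Operator(name="aten::add", overload="out") to emitter_state.operators and caches
        # its index; any later call with the same key returns the cached index without appending.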
        key = (name, overload)
        op_index = self.emitter_state.operator_cache.get(key)
        if op_index is not None:
            return op_index, self.emitter_state.operators[op_index]

        op_index, operator = len(self.emitter_state.operators), Operator(
            name=name, overload=overload
        )
        self.emitter_state.operators.append(operator)
        self.emitter_state.operator_cache[key] = op_index
        return op_index, operator

    def _emit_operator(  # noqa: C901
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Emits an operator (aten or custom), translating it directly to a KernelCall instruction."""
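        # Illustrative arg layout (hypothetical ids): for an out-variant op such as
        # aten::add.out(self, other, *, alpha, out), the emitted KernelCall args are the schema
        # arguments in order followed by the flattened return value(s), so the out tensor's id
        # appears twice: [self.id, other.id, alpha.id, out.id, out.id].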
        assert isinstance(
            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
        ), f"target is {target}"

        # Grab the name.
        op_name = target._overloadpacket._qualified_op_name
        op_overload = ""
        if target._overloadname != "default":
            op_overload = target._overloadname

        def _get_empty_tensor_evalue() -> EValue:
            """Constructs an EValue for an empty tensor."""
            return EValue(
                Tensor(
                    scalar_type=ScalarType.BYTE,
                    # The runtime currently only supports tensors with offset 0.
                    storage_offset=0,
                    sizes=[0],
                    dim_order=[],
                    requires_grad=False,
                    layout=0,
                    data_buffer_idx=0,
                    allocation_info=None,
                    shape_dynamism=TensorShapeDynamism.STATIC,
                )
            )

        op_index, operator = self._get_operator(name=op_name, overload=op_overload)

        # Emit the args and kwargs in the order according to the function schema.
        kernel_args = []
        out_args = []
        for i, schema_arg in enumerate(target._schema.arguments):
            if schema_arg.name in kwargs:
                kernel_arg = kwargs[schema_arg.name]
            elif not schema_arg.kwarg_only and i < len(args):
                kernel_arg = args[i]
            else:
                # Emit default values.
                kernel_arg = schema_arg.default_value

            if kernel_arg is None and isinstance(schema_arg.type, torch.TensorType):
                kernel_arg = self._emit_evalue(_get_empty_tensor_evalue())

            kernel_args.append(self._emit_argument(kernel_arg, schema_arg.type).id)

            if schema_arg.is_out:
                out_args.append((schema_arg.name, kernel_arg))

        if is_out_variant(op_name, op_overload):
            ret = [val for _, val in out_args]
            ret = ret[0] if len(ret) == 1 else ret
        elif is_sym_op(target):
            assert (
                len(target._schema.returns) == 1
            ), "Only returning a single Sym from symbolic ops is supported currently."
            assert type(target._schema.returns[0].type) in (
                torch.IntType,
                torch.FloatType,
                torch.BoolType,
                torch.NumberType,
            ), f"Only symbolic ops that return an Int, Bool, or Float are supported currently; got {type(target._schema.returns[0].type)}."
            ret = self._get_sym_ret(target._schema.returns[0])
            if ret is None:  # type(target._schema.returns[0].type) == torch.NumberType:
                # Can't definitively say what type this is; the runtime operator overrides the
                # EValue completely anyway, so we can serialize anything as a placeholder.
                ret = self._emit_evalue(EValue(Int(0)))
        else:
            ret = self._emit_spec(self.node.meta["spec"])

        out_args = (
            self._emit_evalue(
                EValue(TensorList([cast(_AbstractValue, val).id for val in ret]))
            )
            if isinstance(ret, list)
            else ret
        )

        for elem in pytree.tree_flatten(out_args)[0]:
            kernel_args.append(cast(_AbstractValue, elem).id)

        self.chain.instructions.append(
            Instruction(KernelCall(op_index=op_index, args=kernel_args))
        )
        self._add_debug_handle(len(self.chain.instructions) - 1, target)

        # Get the stacktrace if it exists for each instruction.
        if self.emitter_state.emit_stacktrace:
            stack_trace = self.node.meta["stack_trace"]
            chain_stacktrace = self.chain.stacktrace or []

            chain_stacktrace.append(_stacktrace_to_framelist(stack_trace))
            self._internal_assert_emitter(
                len(chain_stacktrace) == len(self.chain.instructions),
                self.node,
                f"Each instruction should have a corresponding stacktrace; received \
                {len(self.chain.instructions)} instructions and {len(chain_stacktrace)} stacktraces",
            )
            self.chain.stacktrace = chain_stacktrace

        return cast(_EmitterValue, ret)

    def _emit_free(self, spec: TensorSpec) -> _AbstractValue:
        """Emits a FreeCall instruction to release a given Unbounded Tensor's memory."""
        self.chain.instructions.append(
            Instruction(FreeCall(value_index=self.emitter_state.spec2id(spec)))
        )
        # The value is not used but the caller expects an AbstractValue returned.
        return _AbstractValue(None, None)  # pyre-ignore

    def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
        """
        Given a mapping of function names to return values, emit simple execution
        plans that just return these constant values.

        Precondition: All the values are primitives (bool, float, int, str, enum)
        or structures (list, dict) of them.
        """
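        # Illustrative input/output (the getter names and values here are hypothetical):
        #   {"get_version": 3, "get_flags": [True, False]}
        # becomes one ExecutionPlan per method, e.g. a plan named "get_version" whose values are
        # [Int(3)], with an empty instruction chain and outputs [0].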
        plans = []
        # Flatten any structures.
        for method, vals in prim_getters.items():
            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
            flattened_output, spec = ex_pytree.tree_flatten(vals)
            spec = spec.to_str()
            chain = Chain(
                inputs=[],
                outputs=[],
                instructions=[],
                stacktrace=None,
            )

            # Switch on the type of the prim.
            values = []
            for val in flattened_output:
                if isinstance(val, float):
                    values.append(EValue(Double(val)))

                elif isinstance(val, bool):
                    values.append(EValue(Bool(val)))

                elif isinstance(val, int):
                    values.append(EValue(Int(val)))

                elif isinstance(val, str):
                    values.append(EValue(String(val)))

                elif isinstance(val, torch.dtype):
                    values.append(EValue(Int(scalar_type_enum(val))))

                elif isinstance(val, torch.layout):
                    values.append(EValue(Int(layout_enum(val))))

                elif isinstance(val, torch.Tensor):
                    values.append(
                        self._tensor_spec_to_evalue(
                            TensorSpec.from_tensor(val, const=True)
                        )
                    )

                else:
                    raise ExportError(
                        ExportErrorType.NOT_SUPPORTED,
                        f"Error emitting {method}, which returns a value of type {type(val)}, which is not a supported primitive",
                    )

            # Add to plans.
            plans.append(
                ExecutionPlan(
                    name=method,
                    values=values,
                    inputs=[],
                    outputs=list(range(0, len(values))),
                    chains=[chain],
                    operators=[],
                    delegates=[],
                    non_const_buffer_sizes=[0],
                    container_meta_type=ContainerMetadata("", spec),
                )
            )
        return plans

    def fetch_attr(self, target: _Target) -> _AbstractValue:
        """Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
        attr = super().fetch_attr(target)  # pyre-fixme[6]

        if isinstance(attr, torch.Tensor):
            return self._emit_evalue(
                self._tensor_spec_to_evalue(TensorSpec.from_tensor(attr, const=True))
            )

        elif isinstance(attr, torch._C.ScriptObject):
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                f"Custom class {attr} is not supported in EXIR",
            )

        else:
            return attr

    def call_module(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Unsupported in execution IR, so unhandled by the emitter."""
        raise InternalError(
            self._emit_node_specific_error(self.node, "call_module is not supported")
        )

    def call_method(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Unsupported in execution IR, so unhandled by the emitter."""
        raise InternalError(
            self._emit_node_specific_error(self.node, "call_method is not supported")
        )

    def placeholder(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _AbstractValue:
        """Performs actions for the placeholder node of a graph module.

        The placeholder node of the top level entry point is handled by TopLevelEmitter. This
        function only executes on control flow subgraphs. Takes the inputs of the subgraph that
        have not previously been emitted and emits them.
        """
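        # Note: the -1 sentinel handled below is the one installed by _emit_map (and similar
        # callers) via binding_input_values=[-1, *inputs]; those callers read the freshly
        # allocated id back out of binding_input_values after this sub-emitter runs.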
        # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
        value = self.binding_input_values[self.placeholder_count]
        # This indicates that the placeholder wasn't allocated an EValue id before this sub-emitter
        # was run, so we generate one now.
        if value == -1:
            value = self._emit_evalue(
                self._tensor_spec_to_evalue(self.node.meta["spec"])
            )
            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
            self.binding_input_values[self.placeholder_count] = value
        self.placeholder_count += 1
        return value

1341*523fa7a6SAndroid Build Coastguard Worker    def output(  # pyre-fixme[14]
1342*523fa7a6SAndroid Build Coastguard Worker        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
1343*523fa7a6SAndroid Build Coastguard Worker    ) -> None:
1344*523fa7a6SAndroid Build Coastguard Worker        """Performs actions for the output node of a graph module.
1345*523fa7a6SAndroid Build Coastguard Worker
1346*523fa7a6SAndroid Build Coastguard Worker        The output node of the top level entry point is handled by TopLevelEmitter. This function
1347*523fa7a6SAndroid Build Coastguard Worker        only executes on control flow subgraphs. Takes the outputs of the subgraph (if any) and
1348*523fa7a6SAndroid Build Coastguard Worker        inserts instructions to move them to the common output location between control flow
1349*523fa7a6SAndroid Build Coastguard Worker        branches.
1350*523fa7a6SAndroid Build Coastguard Worker        """
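        # A hypothetical illustration of the moves emitted below: if the true branch of a cond
        # produced its result in EValue 7 while both branches are bound to the shared output
        # slot EValue 3, we append Instruction(MoveCall(move_from=7, move_to=3)) so execution
        # converges on the same EValue regardless of which branch ran.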
        self.concrete_output_ids = list(pytree.tree_flatten(args[0])[0])
        binding_output_values = self.binding_output_values
        if binding_output_values is not None:
            binding_output_list, _ = pytree.tree_flatten(binding_output_values)

            self._internal_assert_emitter(
                len(binding_output_list) == len(self.concrete_output_ids),
                self.node,
                "The number of binding output values should match the args to output",
            )

            for move_from, move_to in zip(
                self.concrete_output_ids, binding_output_list
            ):
                if move_from != move_to:
                    instruction = Instruction(
                        MoveCall(move_from=move_from.id, move_to=move_to.id)
                    )
                    self.chain.instructions.append(instruction)

    def call_function(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Performs actions for the call_function node of a graph module.

        Dispatches based on 'target' and emits the corresponding function. 'call_function' is a
        powerful node that covers many operations, ranging from control flow to memory
        management to delegate and operator calls.
        """
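        # A sketch of the dispatch below on two representative targets: operator.getitem is pure
        # bookkeeping (it indexes into an already-emitted multi-output value and emits no
        # instruction), whereas an OpOverload such as aten.add.Tensor goes through
        # _emit_operator, which appends a KernelCall instruction to the chain.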

        # Delegate and operator calls are the only functions that should have a debug handle
        # associated with them. All others, such as memory.alloc and getitem, should be ignored.
        # Default to none and let delegates and ops override.
        if target == operator.getitem:
            assert len(args) == 2
            head = typing.cast(Mapping[int, _EmitterValue], args[0])
            index = typing.cast(int, args[1])
            return head[index]

        elif target == memory.alloc:
            assert len(args) == 1
            return self._emit_spec(self.node.meta["spec"])

        elif target == memory.view:
            return self._emit_view(args)

        elif target == memory.free:
            assert len(args) == 1
            # pyre-ignore
            return self._emit_free(args[0])

        elif target is torch.ops.higher_order.cond:
            return self._emit_control_flow(target, args, kwargs)

        elif target is torch.ops.higher_order.map_impl:
            return self._emit_control_flow(target, args, kwargs)

        elif target == executorch_call_delegate:
            lowered_module = args[0]
            assert is_lowered_module(lowered_module)
            v = self._emit_delegate(lowered_module, args[1:], kwargs)
            delegate_instruction_id = len(self.chain.instructions) - 1
            self._add_debug_handle(delegate_instruction_id, target, lowered_module)
            self._add_delegate_map(lowered_module, delegate_instruction_id)
            return v

        elif isinstance(
            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
        ):
            return self._emit_operator(target, args, kwargs)

        else:
            raise InternalError(
                self._emit_node_specific_error(
                    self.node, f"invalid target for call_function {target}"
                )
            )

    def run(  # pyre-fixme[14]
        self,
        *args: _Argument,
        initial_env: Optional[Dict[torch.fx.Node, _Argument]] = None,
    ) -> None:
        """Traverses all nodes in the graph module and emits each one appropriately."""
        # initial_env is keyword-only on fx.Interpreter.run; io processing is disabled because
        # the emitter handles placeholder and output nodes itself.
        super().run(*args, initial_env=initial_env, enable_io_processing=False)

    def run_node(self, n: torch.fx.Node) -> None:
        """Executes and emits the specified node.

        For more context on what a node is and what execution means see
        https://pytorch.org/docs/stable/fx.html#torch.fx.Node
        """
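        # The wrapping below means any stray exception (e.g. a KeyError on a missing meta entry)
        # resurfaces as an InternalError annotated with the offending node, while InternalError
        # and ExportError raised deliberately by emitter logic propagate unchanged.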
        self.node = n
        try:
            ret = super().run_node(n)
        except Exception as e:
            if isinstance(e, (InternalError, ExportError)):
                raise e
            else:
                raise InternalError(
                    self._emit_node_specific_error(self.node, str(e))
                ) from e
        return ret


class _TopLevelEmitter(_Emitter):
    """An emitter that manages the root level operations within a graph module.

    Exists as a separate class so that 'Emitter' can handle the special behavior of 'placeholder'
    and 'output' nodes in control flow submodules.
    """

    def __init__(
        self,
        name: str,
        exported_program: ExportedProgram,
        graph_module: torch.fx.GraphModule,
        program_state: _ProgramState,
        emitter_state: _EmitterState,
    ) -> None:
        super().__init__(graph_module, emitter_state, program_state)
        self.name = name
        self.exported_program = exported_program

        self.inputs: List[int] = []
        self.outputs: List[int] = []
        self.given_mutable_buffer_warning = False

        def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
            if spec is None:
                return ""
            assert isinstance(spec, pytree.TreeSpec), type(spec)
            dummy_leaves = [0] * spec.num_leaves
            tree = torch.utils._pytree.tree_unflatten(dummy_leaves, spec)
            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
            _, tree = ex_pytree.tree_flatten(tree)
            return tree.to_str()
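        # A rough illustration (the exact string format is internal to ex_pytree): round-tripping
        # the TreeSpec through dummy leaves discards leaf values and keeps only container
        # structure, so the serialized string lets the runtime re-nest flat inputs/outputs,
        # e.g. telling f(x, y) apart from f((x, y)).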

        inp_container_str = create_container_str(exported_program.call_spec.in_spec)
        out_container_str = create_container_str(exported_program.call_spec.out_spec)

        self.container_meta_type = ContainerMetadata(
            inp_container_str, out_container_str
        )

    def _find_fqn_for_placeholder(
        self, target: _Target, spec: Any  # pyre-ignore[2]
    ) -> Tuple[Optional[str], bool]:
        # Find the fully qualified name.
        fqn = None
        is_mutable_buffer = False
        if target in self.exported_program.graph_signature.inputs_to_parameters:
            fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

        elif target in self.exported_program.graph_signature.inputs_to_buffers:
            fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

            # If the buffer is mutated then record that.
            if fqn in self.exported_program.graph_signature.buffers_to_mutate.values():
                is_mutable_buffer = True
                if not self.given_mutable_buffer_warning:
                    warnings.warn(
                        "Mutation on a buffer in the model is detected. ExecuTorch assumes "
                        "buffers that are mutated in the graph have a meaningless initial state, "
                        "only the shape and dtype will be serialized.",
                        UserWarning,
                        stacklevel=1,
                    )
                    self.given_mutable_buffer_warning = True

        elif (
            target
            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
        ):
            fqn = (
                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                    target
                ]
            )
        return fqn, is_mutable_buffer
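    # An illustration with hypothetical names: after export lifts parameters to inputs, a
    # placeholder target like "p_fc_weight" typically maps through inputs_to_parameters to the
    # fqn "fc.weight" (is_mutable_buffer=False); a mutated buffer target like "b_running_mean"
    # maps through inputs_to_buffers to "running_mean" with is_mutable_buffer=True; a plain user
    # input matches none of the signature maps and yields (None, False).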

    def placeholder(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _AbstractValue:
        """Emits the value within the placeholder node.

        For more information on placeholder nodes see
        https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
        """
        spec = self.node.meta["spec"]
        is_user_input = True

        if isinstance(target, str) and isinstance(spec, TensorSpec):
            fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

            # From the fqn find the corresponding tensor.
            real_tensor = None
            if fqn in self.exported_program.state_dict:
                real_tensor = self.exported_program.state_dict[fqn]
                is_user_input = False

            elif fqn in self.exported_program.constants:
                real_tensor = self.exported_program.constants[fqn]
                is_user_input = False
            elif fqn is not None:
                buffers = self.exported_program.named_buffers()
                buf = next((x[1] for x in buffers if x[0] == fqn), None)
                if buf is not None:
                    real_tensor = buf
                    is_user_input = False
                else:
                    raise InternalError(
                        self._emit_node_specific_error(
                            self.node,
                            f"Could not find buffer with fqn {fqn} in state_dict or named_buffers",
                        )
                    )

            # Assign the storage of the placeholder spec to the storage of the real tensor if
            # there is one.
            if real_tensor is not None:
                # For non-contiguous tensors, convert to a contiguous one.
                real_tensor = real_tensor.contiguous()
                # Weights cannot be views during emission or serialization, so clone any tensor
                # whose bytes do not span its entire underlying storage.
                if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
                    real_tensor = real_tensor.clone()

                spec.storage = real_tensor.untyped_storage()

            # User inputs and mutable buffers are not constants, other buffers or parameters are.
            spec.const = not (is_user_input or is_mutable_buffer)

        evalue = (
            self._tensor_spec_to_evalue(spec)
            if isinstance(spec, TensorSpec)
            else self._constant_to_evalue(spec, None)
        )
        value = self._emit_evalue(evalue)

        # Only user inputs should remain as inputs.
        if is_user_input:
            self.inputs.append(value.id)

        return value
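    # A minimal sketch of the two paths above, with hypothetical names: a user input "x" resolves
    # to no fqn, so its spec stays non-const and its EValue id lands in self.inputs; a parameter
    # placeholder resolving to fqn "fc.weight" pulls the real tensor from state_dict, attaches
    # its untyped storage to the spec, and is emitted as a constant EValue that never appears in
    # self.inputs.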

    def output(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Records the ExecutionPlan's outputs based on the output node in the graph."""
        if isinstance(args[0], dict):
            args_tuple, _ = pytree.tree_flatten(args[0])
        else:
            args_tuple = typing.cast(Tuple[_AbstractValue, ...], args[0])
        if isinstance(args_tuple, _AbstractValue):
            self.outputs.append(args_tuple.id)
        else:
            for arg in args_tuple:
                if isinstance(arg, (int, float, bool, type(None))):
                    arg = self._emit_evalue(self._constant_to_evalue(arg, None))
                elif isinstance(arg, str):
                    # TODO(jackkhuu): T181599879 Add support for string outputs IFF compiler supports
                    raise InternalError(
                        self._emit_node_specific_error(
                            self.node,
                            f"Returning {arg} is not yet supported in the emitter.",
                        )
                    )
                else:
                    # Every other output should already have its value emitted.
                    # They should only be abstract IDs at this point.
                    assert isinstance(arg, _AbstractValue)
                self.outputs.append(arg.id)
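    # For example, a graph returning (tensor_result, 1) records the already-emitted EValue id of
    # tensor_result directly, while the literal 1 is first emitted as a fresh constant EValue so
    # that every entry appended to self.outputs is a valid EValue id.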

    def plan(self) -> ExecutionPlan:
        """Returns the execution plan emitted from this entry point."""
        return ExecutionPlan(
            name=self.name,
            values=self.emitter_state.values,
            inputs=self.inputs,
            outputs=self.outputs,
            chains=[self.chain],
            operators=self.emitter_state.operators,
            delegates=self.emitter_state.delegates,
            # The non_const_buffer_sizes field is set by the memory_planning_pass. In case the
            # field is missing, as in unit tests that do not enable memory planning, assume an
            # empty list.
            non_const_buffer_sizes=typing.cast(
                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB...
                List[int], self.module.meta["non_const_buffer_sizes"]
            ),
            container_meta_type=self.container_meta_type,
        )
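

# A hedged usage sketch (state construction elided; see emit_program elsewhere in exir/emit for
# the real call site): one _TopLevelEmitter is built per entry point, run over its graph, and
# asked for its plan, roughly:
#
#   emitter = _TopLevelEmitter("forward", exported_program, gm, program_state, emitter_state)
#   emitter.run()
#   execution_plan = emitter.plan()
#
# The resulting ExecutionPlans are then collected into the serialized Program.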