# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Takes an ExportedArtifact, or a collection of ExportedArtifacts, in execution dialect, and turns
them into a single ExecuTorch Program.

The provided ExportedArtifact's graph modules are in execution dialect and the emitter parses and
converts them into executorch instructions. The emitter walks the provided graphs and, as it
encounters concrete values such as tensors or ints, converts them to the serialized format and
stores them in a list for later use. The emitter walks the graph by traversing fx.Nodes; these can
come in a variety of forms and are the primitives of execution at the graph module level. The three
most common we care about are 'call_function', 'placeholder', and 'output'. 'placeholder' and
'output' handle io for the module and 'call_function' handles everything else.
Within 'call_function' we may encounter an operator or delegate call. In that case we parse the
schema and emit all the inputs and outputs (unless they have already been emitted), and then we
convert the actual function call into an executorch instruction such as KernelCall or DelegateCall.

When control flow is present in the graph module it will take the form of a few different types of
'call_function'. Today (June 14th 2023) only cond and map are supported. The actual operations of
these, such as the true/false branches of cond, or the mapping function of map, are stored as sub
graph modules. When these are encountered during emission, the emitter will recursively emit them
and their instructions.
"""
# TODO(jakeszwe): add information here about how weights and other parameters are handled in the
# presence of aot autograd param lifting.
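The node walk described in the docstring above can be sketched without any torch dependency. A hypothetical, stdlib-only illustration follows; `FakeNode` and `SketchEmitter` are made-up stand-ins for this sketch, and only the op strings ('placeholder', 'call_function', 'output') mirror torch.fx conventions:

```python
# Hypothetical sketch of the emission walk described in the module docstring.
# FakeNode and SketchEmitter are illustrative stand-ins, not real emitter types.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class FakeNode:
    op: str  # 'placeholder', 'call_function', or 'output'
    name: str


@dataclass
class SketchEmitter:
    # Serialized values table; the real emitter stores EValues here.
    values: List[str] = field(default_factory=list)
    # (instruction kind, output value index) pairs.
    instructions: List[Tuple[str, int]] = field(default_factory=list)
    outputs: List[int] = field(default_factory=list)

    def emit(self, nodes: List[FakeNode]) -> None:
        for node in nodes:
            if node.op == "placeholder":
                # Module input: reserve a slot in the values table.
                self.values.append(node.name)
            elif node.op == "call_function":
                # Emit the op's output value, then the instruction itself.
                self.values.append(node.name)
                self.instructions.append(("KernelCall", len(self.values) - 1))
            elif node.op == "output":
                # Record which value index the module returns.
                self.outputs.append(len(self.values) - 1)


emitter = SketchEmitter()
emitter.emit(
    [
        FakeNode("placeholder", "x"),
        FakeNode("call_function", "add"),
        FakeNode("output", "out"),
    ]
)
# emitter.values == ["x", "add"]; emitter.instructions == [("KernelCall", 1)]
```

The real emitter additionally deduplicates already-emitted values and handles delegates and control flow, but the value-table-plus-instruction-list shape is the same.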

# pyre-strict
import ctypes
import hashlib
import operator
import typing
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union

import executorch.exir.memory as memory
import executorch.extension.pytree as ex_pytree
import torch
import torch.fx
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.operator.convert import is_out_variant
from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    BackendDelegateInlineData,
    Bool,
    BoolList,
    Buffer,
    Chain,
    ContainerMetadata,
    DataLocation,
    DelegateCall,
    Double,
    DoubleList,
    EValue,
    ExecutionPlan,
    FreeCall,
    Instruction,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    MoveCall,
    Null,
    Operator,
    OptionalTensorList,
    ScalarType,
    String,
    Tensor,
    TensorList,
    TensorShapeDynamism,
)
from executorch.exir.tensor import (
    AddressSpaceOverflowException,
    layout_enum,
    make_allocation_info,
    make_tensor_value,
    memory_format_enum,
    scalar_type_enum,
    TensorSpec,
)
from executorch.exir.types import LeafValueSpec, ValueSpec
from torch._subclasses.fake_tensor import FakeTensor

from torch.export.exported_program import ExportedProgram
from torch.utils import _pytree as pytree

from typing_extensions import TypeAlias


@dataclass
class _ProgramState:
    """State shared between all methods of a program and the graph module it represents.

    Initialized once within emit_program and then shared across each entry point as they are
    emitted.
    """

    # Parallel list of specs and the buffers that backed them; have to add +1 to any index in here
    # as index 0 in the constant_buffer is reserved.
    allocated_specs: List[TensorSpec] = field(default_factory=list)
    # Weights in any arbitrary graph_module only need to compare against weights from previously
    # emitted graph modules, not any weights emitted from itself.
    # This should speed up the lookup from O(N) to O(1).
    cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
    cached_spec_mutable_hash_values: Dict[str, int] = field(default_factory=dict)
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
    mutable_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
    # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
    # and should be copied to Program.backend_delegate_data.
    backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)


@dataclass
class _EmitterState:
    """State of a single emitter.

    Local to at least the entry point, and may be local to a subgraph of an entry point originating
    from control flow.
    """

    values: List[EValue]
    operators: List[Operator]
    delegates: List[BackendDelegate]
    operator_cache: Dict[Tuple[str, str], int]
    delegate_cache: Dict[bytes, int]
    emit_stacktrace: bool

    spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)

    def spec2id(self, spec: TensorSpec) -> int:
        """Map a TensorSpec to its value index in the values array."""
        assert spec in self.spec2id_dict, f"Spec is not found: {spec.debug()}"
        return self.spec2id_dict[spec]


@dataclass
class _AbstractValue:
    """Represents an already emitted EValue."""

    # Index in the values table of this EValue.
    id: int

    # Used for sanity checks for functions that expect to only receive AbstractValues.
    tensor: Optional[Tensor]


_EmitterValue: TypeAlias = Union[
    _AbstractValue, List[_AbstractValue], Tuple[_AbstractValue, ...]
]

_PythonValue: TypeAlias = Union[bool, int, float, torch.Tensor, List["_PythonValue"]]
_SchemaType: TypeAlias = Union[
    torch.OptionalType,
    torch.ListType,
    torch.FloatType,
    torch.BoolType,
    torch.IntType,
    torch.StringType,
    torch.TensorType,
]

_Target: TypeAlias = Union[Callable[..., _PythonValue], str]

_Argument: TypeAlias = Union[
    _EmitterValue,
    Tuple["_Argument", ...],
    List["_Argument"],
    Dict[str, "_Argument"],
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.Tensor,
    torch.memory_format,
    torch.layout,
    None,
]

_DelegateDebugIdentifierMap: TypeAlias = Union[
    Dict[int, Tuple[int]], Dict[str, Tuple[int]]
]


# pyre-ignore[13]: Attribute `node` is never initialized.
class _Emitter(torch.fx.Interpreter):
    """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
    given traced torch.fx.GraphModule to the flatbuffer schema."""

    node: torch.fx.Node

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        emitter_state: _EmitterState,
        program_state: _ProgramState,
        instruction_start_offset: int = 0,
        binding_input_values: Optional[List[_AbstractValue]] = None,
        binding_output_values: Optional[List[_AbstractValue]] = None,
    ) -> None:
        super().__init__(graph_module)
        self.emitter_state = emitter_state
        self.program_state = program_state
        self.outputs: List[int] = []

        self.chain = Chain(
            inputs=[],
            outputs=[],
            instructions=[],
            stacktrace=None,
        )

        if "non_const_buffer_sizes" not in graph_module.meta.keys():
            raise RuntimeError(
                "Must set 'non_const_buffer_sizes' in graph meta in memory planning pass"
            )
        self.instruction_start_offset = instruction_start_offset
        self.binding_input_values = binding_input_values
        self.binding_output_values = binding_output_values
        self.graph_module: torch.fx.GraphModule = graph_module
        self.nodes: List[torch.fx.Node] = list(self.graph_module.graph.nodes)

        # Marks each placeholder node with its order so that we can match it with the
        # corresponding _AbstractValue coming from the top level.
        self.placeholder_count = 0

        self.concrete_output_ids: List[_AbstractValue] = []
        self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
        self.instr_id_to_delegate_debug_id_map: Dict[
            int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
        ] = {}

    def _emit_node_specific_error(self, node: torch.fx.Node, err_msg: str) -> str:
        """Returns 'err_msg' with node specific information attached."""
        err_msg = f"Failed with error: {str(err_msg)}\n" + inspect_node(
            self.graph_module.graph, node
        )
        return err_msg

    def _internal_assert_emitter(
        self, pred: bool, node: torch.fx.Node, assert_msg: str
    ) -> None:
        """If pred is False, construct and raise a node specific error message."""
        if not pred:
            raise InternalError(self._emit_node_specific_error(node, assert_msg))

    def _emit_int_list(self, val: List[_Argument]) -> EValue:
        """Emits a list of integers as a collection of EValues.

        For every argument in 'val':
        - If it is a concrete value, then emit it and then place its location in the boxed list
        - If it is already an abstract value, then just place its location in the boxed list

        Int lists are boxed to handle symints whose values are determined at runtime, but could
        still end up inside lists for ops like view_copy(Tensor self, SymInt[] size).
        """
        boxed_list = []
        for item in val:
            if isinstance(item, _AbstractValue):
                boxed_list.append(item.id)
            elif isinstance(item, int):
                boxed_list.append(
                    self._emit_evalue(self._constant_to_evalue(item, None)).id
                )
            else:
                self._internal_assert_emitter(
                    False, self.node, "Unsupported type encountered in int list."
                )

        return EValue(IntList(boxed_list))

    def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
        """Emits a list type.

        Emits the list stored in val. If the list is of Tensors, Optionals, or Ints the emitted
        list is boxed, otherwise the values are constant at runtime and stored inline.

        NOTE: When symbool and symfloat are supported, bool and float lists will be stored boxed.
        """

        if isinstance(val_type, torch.BoolType):
            return EValue(BoolList(typing.cast(List[bool], val)))

        if isinstance(val_type, torch.IntType):
            return self._emit_int_list(val)

        if isinstance(val_type, torch.FloatType):
            return EValue(DoubleList(typing.cast(List[float], val)))

        if isinstance(val_type, torch.TensorType):
            values = []
            for v in val:
                assert isinstance(v, _AbstractValue)
                self._internal_assert_emitter(
                    v.tensor is not None,
                    self.node,
                    "AbstractValue corresponding to tensor type doesn't contain tensor value",
                )
                values.append(v.id)
            return EValue(TensorList(values))

        if isinstance(val_type, torch.OptionalType):
            # refine further
            actual_type = val_type.getElementType()
            if isinstance(actual_type, torch.TensorType):
                vals = []
                for v in val:
                    if v is None:
                        vals.append(-1)
                    else:
                        assert isinstance(v, _AbstractValue)
                        vals.append(v.id)
                return EValue(OptionalTensorList(vals))

        raise ExportError(
            ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}"
        )

    def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
        """Constructs an EValue from the given TensorSpec."""

        allocation_info = None
        buffer_idx = 0

        # Need to memory plan.
        # Some users set mem_id on all tensors and then rely on the
        # default algos to set offsets, so need to check both.
        if spec.mem_id is not None and spec.mem_offset is not None:
            # Tensor is an activation.
            self._internal_assert_emitter(
                isinstance(spec.mem_id, int) and spec.mem_id >= 0,
                self.node,
                f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}",
            )

            self._internal_assert_emitter(
                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
                self.node,
                f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}",
            )
            try:
                allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset)
            except AddressSpaceOverflowException as e:
                raise InternalError(
                    self._emit_node_specific_error(
                        self.node,
                        (
                            f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, "
                            f"such as ConstraintBasedSymShapeEvalPass, this may be the cause of an "
                            f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) "
                            "during torch.export()."
                        ),
                    )
                )

        if spec.const:
            # Tensor with a blob we need to serialize. May not actually be constant at runtime
            # if it's a weight with an associated gradient.
            spec_array_type = (
                ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
            )

            buffer_data = (
                bytes(
                    ctypes.cast(
                        typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
                        ctypes.POINTER(spec_array_type),
                    ).contents
                )
                if spec.allocated_memory != 0
                else b""
            )

            hashed = hashlib.sha256(buffer_data).hexdigest()

            if allocation_info:
                buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                    hashed, -1
                )
            else:
                buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)

            # Haven't seen this constant before.
            if buffer_idx == -1:
                # Update buffer_idx to point to the end of the list where we are adding the
                # new buffer.
                buffer = Buffer(storage=buffer_data)
                self.program_state.allocated_specs.append(spec)
                # +1 because the first buffer location is reserved.

                if allocation_info:
                    buffer_idx = len(self.program_state.mutable_buffer)
                    self.program_state.cached_spec_mutable_hash_values[hashed] = (
                        buffer_idx
                    )
                    self.program_state.mutable_buffer.append(buffer)
                else:
                    buffer_idx = len(self.program_state.constant_buffer)
                    self.program_state.cached_spec_hash_values[hashed] = buffer_idx
                    self.program_state.constant_buffer.append(buffer)

        if spec.const and spec.nbytes() != len(buffer_data):
            raise InternalError(
                self._emit_node_specific_error(
                    self.node,
                    f"Tensor spec has buffer of size {len(buffer_data)}, but expected nbytes of {spec.nbytes()}",
                )
            )

        # For constant tensors, allocation_info = None.
        return EValue(make_tensor_value(buffer_idx, allocation_info, spec))

    def _get_list_tuple_jit_type(
        self, val: Union[Tuple[_Argument], List[_Argument]]
    ) -> _SchemaType:
        """Returns the JIT type for the given python type."""
        assert isinstance(
            val, (list, tuple)
        ), f"Input to _get_list_tuple_jit_type was expected to be an instance of a list or tuple but received {type(val)}"
        is_tensor_type = all(
            isinstance(v, _AbstractValue) and v.tensor is not None for v in val
        )
        if is_tensor_type:
            return torch.TensorType.get()
        elif isinstance(val[0], int):
            return torch.IntType.get()
        elif isinstance(val[0], bool):
            return torch.BoolType.get()
        elif isinstance(val[0], float):
            return torch.FloatType.get()

        raise InternalError(
self._emit_node_specific_error( 445*523fa7a6SAndroid Build Coastguard Worker self.node, 446*523fa7a6SAndroid Build Coastguard Worker "Couldn't determine JitType for list/tuple of elements. Only supports int, float, bool, and Tensor.", 447*523fa7a6SAndroid Build Coastguard Worker ) 448*523fa7a6SAndroid Build Coastguard Worker ) 449*523fa7a6SAndroid Build Coastguard Worker 450*523fa7a6SAndroid Build Coastguard Worker def _constant_to_evalue( # noqa: C901 451*523fa7a6SAndroid Build Coastguard Worker self, 452*523fa7a6SAndroid Build Coastguard Worker val: _Argument, 453*523fa7a6SAndroid Build Coastguard Worker val_type: Optional[_SchemaType], 454*523fa7a6SAndroid Build Coastguard Worker ) -> EValue: 455*523fa7a6SAndroid Build Coastguard Worker """Converts a constant value to an EValue. 456*523fa7a6SAndroid Build Coastguard Worker 457*523fa7a6SAndroid Build Coastguard Worker Returns an EValue given the Python representation and JIT type. On common paths there should 458*523fa7a6SAndroid Build Coastguard Worker always be a JIT type provided. Users can pass in a None to infer the JIT type but this 459*523fa7a6SAndroid Build Coastguard Worker should never be the default case due to the existence of container types. 460*523fa7a6SAndroid Build Coastguard Worker """ 461*523fa7a6SAndroid Build Coastguard Worker if val is None: 462*523fa7a6SAndroid Build Coastguard Worker return EValue(Null()) 463*523fa7a6SAndroid Build Coastguard Worker 464*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, (list, tuple)): 465*523fa7a6SAndroid Build Coastguard Worker # Refine Optional[List[T]] -> List[T] This works because if the val was None it would 466*523fa7a6SAndroid Build Coastguard Worker # have converted to Null before this function call. 
467*523fa7a6SAndroid Build Coastguard Worker if val_type is None: 468*523fa7a6SAndroid Build Coastguard Worker val_type = torch.ListType( 469*523fa7a6SAndroid Build Coastguard Worker self._get_list_tuple_jit_type(val) # pyre-ignore 470*523fa7a6SAndroid Build Coastguard Worker ) 471*523fa7a6SAndroid Build Coastguard Worker if isinstance(val_type, torch.OptionalType): 472*523fa7a6SAndroid Build Coastguard Worker val_type = val_type.getElementType() 473*523fa7a6SAndroid Build Coastguard Worker assert isinstance(val_type, torch.ListType) 474*523fa7a6SAndroid Build Coastguard Worker return self._emit_list( 475*523fa7a6SAndroid Build Coastguard Worker typing.cast(List[_Argument], val), 476*523fa7a6SAndroid Build Coastguard Worker typing.cast(_SchemaType, val_type.getElementType()), 477*523fa7a6SAndroid Build Coastguard Worker ) 478*523fa7a6SAndroid Build Coastguard Worker 479*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, float): 480*523fa7a6SAndroid Build Coastguard Worker return EValue(Double(val)) 481*523fa7a6SAndroid Build Coastguard Worker 482*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, bool): 483*523fa7a6SAndroid Build Coastguard Worker return EValue(Bool(val)) 484*523fa7a6SAndroid Build Coastguard Worker 485*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, int): 486*523fa7a6SAndroid Build Coastguard Worker return EValue(Int(val)) 487*523fa7a6SAndroid Build Coastguard Worker 488*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, str): 489*523fa7a6SAndroid Build Coastguard Worker return EValue(String(val)) 490*523fa7a6SAndroid Build Coastguard Worker 491*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, torch.dtype): 492*523fa7a6SAndroid Build Coastguard Worker return EValue(Int(scalar_type_enum(val))) 493*523fa7a6SAndroid Build Coastguard Worker 494*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, torch.layout): 495*523fa7a6SAndroid Build Coastguard Worker return EValue(Int(layout_enum(val))) 
496*523fa7a6SAndroid Build Coastguard Worker 497*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, torch.memory_format): 498*523fa7a6SAndroid Build Coastguard Worker try: 499*523fa7a6SAndroid Build Coastguard Worker return EValue(Int(memory_format_enum(val))) 500*523fa7a6SAndroid Build Coastguard Worker except KeyError: 501*523fa7a6SAndroid Build Coastguard Worker raise InternalError( 502*523fa7a6SAndroid Build Coastguard Worker self._emit_node_specific_error( 503*523fa7a6SAndroid Build Coastguard Worker self.node, 504*523fa7a6SAndroid Build Coastguard Worker f"Tensor has a memory_format that is unsupported in ExecuTorch: {val}", 505*523fa7a6SAndroid Build Coastguard Worker ) 506*523fa7a6SAndroid Build Coastguard Worker ) 507*523fa7a6SAndroid Build Coastguard Worker 508*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, torch.Tensor): 509*523fa7a6SAndroid Build Coastguard Worker raise ExportError( 510*523fa7a6SAndroid Build Coastguard Worker ExportErrorType.NOT_SUPPORTED, 511*523fa7a6SAndroid Build Coastguard Worker self._emit_node_specific_error( 512*523fa7a6SAndroid Build Coastguard Worker self.node, 513*523fa7a6SAndroid Build Coastguard Worker "constant_to_evalue should not be encountering constant tensors, they should be emitted through other codepaths.", 514*523fa7a6SAndroid Build Coastguard Worker ), 515*523fa7a6SAndroid Build Coastguard Worker ) 516*523fa7a6SAndroid Build Coastguard Worker 517*523fa7a6SAndroid Build Coastguard Worker raise ExportError( 518*523fa7a6SAndroid Build Coastguard Worker ExportErrorType.NOT_SUPPORTED, 519*523fa7a6SAndroid Build Coastguard Worker self._emit_node_specific_error( 520*523fa7a6SAndroid Build Coastguard Worker self.node, f"Unsupported constant type: {type(val).__name__}" 521*523fa7a6SAndroid Build Coastguard Worker ), 522*523fa7a6SAndroid Build Coastguard Worker ) 523*523fa7a6SAndroid Build Coastguard Worker 524*523fa7a6SAndroid Build Coastguard Worker def _emit_evalue(self, val: EValue) -> 
_AbstractValue: 525*523fa7a6SAndroid Build Coastguard Worker """Writes an EValue to the emitter state. 526*523fa7a6SAndroid Build Coastguard Worker 527*523fa7a6SAndroid Build Coastguard Worker Given an Evalue, adds it to the emitter_state's values table, and returns the AbstractValue 528*523fa7a6SAndroid Build Coastguard Worker representing it. 529*523fa7a6SAndroid Build Coastguard Worker """ 530*523fa7a6SAndroid Build Coastguard Worker self.emitter_state.values.append(val) 531*523fa7a6SAndroid Build Coastguard Worker tensor = val.val if isinstance(val.val, Tensor) else None 532*523fa7a6SAndroid Build Coastguard Worker return _AbstractValue(len(self.emitter_state.values) - 1, tensor) 533*523fa7a6SAndroid Build Coastguard Worker 534*523fa7a6SAndroid Build Coastguard Worker def _emit_spec(self, spec: ValueSpec) -> _EmitterValue: 535*523fa7a6SAndroid Build Coastguard Worker """Given the provided spec constructs the corresponding EValue from it and then emits it.""" 536*523fa7a6SAndroid Build Coastguard Worker 537*523fa7a6SAndroid Build Coastguard Worker def _process(spec: LeafValueSpec) -> _AbstractValue: 538*523fa7a6SAndroid Build Coastguard Worker if isinstance(spec, (list, tuple)): 539*523fa7a6SAndroid Build Coastguard Worker raise InternalError( 540*523fa7a6SAndroid Build Coastguard Worker self.emit_node_specific_error( 541*523fa7a6SAndroid Build Coastguard Worker self.node, 542*523fa7a6SAndroid Build Coastguard Worker "Node spec should be either non-nested container or a scalar object", 543*523fa7a6SAndroid Build Coastguard Worker ) 544*523fa7a6SAndroid Build Coastguard Worker ) 545*523fa7a6SAndroid Build Coastguard Worker 546*523fa7a6SAndroid Build Coastguard Worker # ScalarSpec can theoretically be handled fine, but it shouldn't be appearing so rather 547*523fa7a6SAndroid Build Coastguard Worker # than handle it, assert that it isn't supposed to be present. 
In the future if it has a 548*523fa7a6SAndroid Build Coastguard Worker # reason to appear we can relax this assert. 549*523fa7a6SAndroid Build Coastguard Worker self._internal_assert_emitter( 550*523fa7a6SAndroid Build Coastguard Worker isinstance(spec, TensorSpec), 551*523fa7a6SAndroid Build Coastguard Worker self.node, 552*523fa7a6SAndroid Build Coastguard Worker f"Invalid node spec expected TensorSpec received {spec}", 553*523fa7a6SAndroid Build Coastguard Worker ) 554*523fa7a6SAndroid Build Coastguard Worker 555*523fa7a6SAndroid Build Coastguard Worker ret = self._emit_evalue(self._tensor_spec_to_evalue(spec)) # pyre-ignore 556*523fa7a6SAndroid Build Coastguard Worker self.emitter_state.spec2id_dict[spec] = ret.id # pyre-ignore 557*523fa7a6SAndroid Build Coastguard Worker return ret 558*523fa7a6SAndroid Build Coastguard Worker 559*523fa7a6SAndroid Build Coastguard Worker return pytree.tree_map(_process, spec) 560*523fa7a6SAndroid Build Coastguard Worker 561*523fa7a6SAndroid Build Coastguard Worker def _merge_chain(self, chain: Chain) -> None: 562*523fa7a6SAndroid Build Coastguard Worker """Merges the chain generated from subgraphs (like those originating from control flow) back 563*523fa7a6SAndroid Build Coastguard Worker into the main program chain.""" 564*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.extend(chain.instructions) 565*523fa7a6SAndroid Build Coastguard Worker 566*523fa7a6SAndroid Build Coastguard Worker def _emit_cond( 567*523fa7a6SAndroid Build Coastguard Worker self, 568*523fa7a6SAndroid Build Coastguard Worker args: Tuple[_Argument, ...], 569*523fa7a6SAndroid Build Coastguard Worker subemitter_binding_output_values: Optional[List[_AbstractValue]], 570*523fa7a6SAndroid Build Coastguard Worker ) -> List[_AbstractValue]: 571*523fa7a6SAndroid Build Coastguard Worker """Emits control_flow.cond. 
572*523fa7a6SAndroid Build Coastguard Worker 573*523fa7a6SAndroid Build Coastguard Worker Converts the higher order op into jumps and inlines the submodules of the true and false 574*523fa7a6SAndroid Build Coastguard Worker branches. Control flow can be nested. The general emitted structure is: <Jump Instruction> - 575*523fa7a6SAndroid Build Coastguard Worker decides which branch to take <True Branch> <Jump Instruction> - jumps to End Of Cond after 576*523fa7a6SAndroid Build Coastguard Worker executing true branch <False Branch> <End Of Cond> 577*523fa7a6SAndroid Build Coastguard Worker """ 578*523fa7a6SAndroid Build Coastguard Worker pred, true_branch, false_branch, inputs = args 579*523fa7a6SAndroid Build Coastguard Worker 580*523fa7a6SAndroid Build Coastguard Worker # Emit the true branch. 581*523fa7a6SAndroid Build Coastguard Worker assert isinstance(true_branch, torch.fx.GraphModule) 582*523fa7a6SAndroid Build Coastguard Worker true_branch_emitter = _Emitter( 583*523fa7a6SAndroid Build Coastguard Worker true_branch, 584*523fa7a6SAndroid Build Coastguard Worker self.emitter_state, 585*523fa7a6SAndroid Build Coastguard Worker self.program_state, 586*523fa7a6SAndroid Build Coastguard Worker instruction_start_offset=self.instruction_start_offset 587*523fa7a6SAndroid Build Coastguard Worker + len(self.chain.instructions) 588*523fa7a6SAndroid Build Coastguard Worker + 1, 589*523fa7a6SAndroid Build Coastguard Worker binding_input_values=typing.cast(List[_AbstractValue], inputs), 590*523fa7a6SAndroid Build Coastguard Worker binding_output_values=subemitter_binding_output_values, 591*523fa7a6SAndroid Build Coastguard Worker ) 592*523fa7a6SAndroid Build Coastguard Worker true_branch_emitter.run() 593*523fa7a6SAndroid Build Coastguard Worker 594*523fa7a6SAndroid Build Coastguard Worker # Emit the jump. 
595*523fa7a6SAndroid Build Coastguard Worker assert isinstance(pred, _AbstractValue) 596*523fa7a6SAndroid Build Coastguard Worker jf_instruction_to_skip_true = Instruction( 597*523fa7a6SAndroid Build Coastguard Worker JumpFalseCall( 598*523fa7a6SAndroid Build Coastguard Worker cond_value_index=pred.id, 599*523fa7a6SAndroid Build Coastguard Worker destination_instruction=self.instruction_start_offset 600*523fa7a6SAndroid Build Coastguard Worker + len(self.chain.instructions) 601*523fa7a6SAndroid Build Coastguard Worker + len(true_branch_emitter.chain.instructions) 602*523fa7a6SAndroid Build Coastguard Worker # This jump instruction should point at instruction that is after all instructions 603*523fa7a6SAndroid Build Coastguard Worker # for the true branch. The reason we add 2 is because we need to account for this 604*523fa7a6SAndroid Build Coastguard Worker # instruction we are creating right now and the jump instruction that true branch 605*523fa7a6SAndroid Build Coastguard Worker # will create. 606*523fa7a6SAndroid Build Coastguard Worker + 2, 607*523fa7a6SAndroid Build Coastguard Worker ) 608*523fa7a6SAndroid Build Coastguard Worker ) 609*523fa7a6SAndroid Build Coastguard Worker 610*523fa7a6SAndroid Build Coastguard Worker # Insert the branch picking jump instruction to the main chain. 611*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(jf_instruction_to_skip_true) 612*523fa7a6SAndroid Build Coastguard Worker # Now that we created the true branch instructions, we move them to the main chain. 
613*523fa7a6SAndroid Build Coastguard Worker self._merge_chain(true_branch_emitter.chain) 614*523fa7a6SAndroid Build Coastguard Worker 615*523fa7a6SAndroid Build Coastguard Worker # emit false branch 616*523fa7a6SAndroid Build Coastguard Worker assert isinstance(false_branch, torch.fx.GraphModule) 617*523fa7a6SAndroid Build Coastguard Worker false_branch_emitter = _Emitter( 618*523fa7a6SAndroid Build Coastguard Worker false_branch, 619*523fa7a6SAndroid Build Coastguard Worker self.emitter_state, 620*523fa7a6SAndroid Build Coastguard Worker self.program_state, 621*523fa7a6SAndroid Build Coastguard Worker instruction_start_offset=self.instruction_start_offset 622*523fa7a6SAndroid Build Coastguard Worker + len(self.chain.instructions) 623*523fa7a6SAndroid Build Coastguard Worker + 1, 624*523fa7a6SAndroid Build Coastguard Worker binding_input_values=typing.cast(List[_AbstractValue], inputs), 625*523fa7a6SAndroid Build Coastguard Worker binding_output_values=subemitter_binding_output_values, 626*523fa7a6SAndroid Build Coastguard Worker ) 627*523fa7a6SAndroid Build Coastguard Worker false_branch_emitter.run() 628*523fa7a6SAndroid Build Coastguard Worker 629*523fa7a6SAndroid Build Coastguard Worker # We bake in constant False because this will trigger the instruction to jump over all false 630*523fa7a6SAndroid Build Coastguard Worker # branch instructions and point at the start of instruction right after control flow. 
631*523fa7a6SAndroid Build Coastguard Worker value = self._emit_evalue(EValue(Bool(False))) 632*523fa7a6SAndroid Build Coastguard Worker jf_instruction_to_skip_false = Instruction( 633*523fa7a6SAndroid Build Coastguard Worker JumpFalseCall( 634*523fa7a6SAndroid Build Coastguard Worker cond_value_index=value.id, 635*523fa7a6SAndroid Build Coastguard Worker destination_instruction=self.instruction_start_offset 636*523fa7a6SAndroid Build Coastguard Worker + len(self.chain.instructions) 637*523fa7a6SAndroid Build Coastguard Worker + len(false_branch_emitter.chain.instructions) 638*523fa7a6SAndroid Build Coastguard Worker + 1, 639*523fa7a6SAndroid Build Coastguard Worker ) 640*523fa7a6SAndroid Build Coastguard Worker ) 641*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(jf_instruction_to_skip_false) 642*523fa7a6SAndroid Build Coastguard Worker self._merge_chain(false_branch_emitter.chain) 643*523fa7a6SAndroid Build Coastguard Worker return subemitter_binding_output_values 644*523fa7a6SAndroid Build Coastguard Worker 645*523fa7a6SAndroid Build Coastguard Worker def _emit_map( 646*523fa7a6SAndroid Build Coastguard Worker self, 647*523fa7a6SAndroid Build Coastguard Worker args: Tuple[_Argument, ...], 648*523fa7a6SAndroid Build Coastguard Worker subemitter_binding_output_values: List[_AbstractValue], 649*523fa7a6SAndroid Build Coastguard Worker ) -> List[_AbstractValue]: 650*523fa7a6SAndroid Build Coastguard Worker """Emits torch.map. 651*523fa7a6SAndroid Build Coastguard Worker 652*523fa7a6SAndroid Build Coastguard Worker Converts the higher order op into a loop constructed from jump instructions and primitive 653*523fa7a6SAndroid Build Coastguard Worker int operations. A concat-like custom op is also injected into the submodule code to handle 654*523fa7a6SAndroid Build Coastguard Worker the construction of the map output. 
655*523fa7a6SAndroid Build Coastguard Worker 656*523fa7a6SAndroid Build Coastguard Worker Below is what the input graph that is provided to emit_map looks like. class 657*523fa7a6SAndroid Build Coastguard Worker TestMapCond(torch.nn.Module): def __init__(self): 658*523fa7a6SAndroid Build Coastguard Worker super().__init__() 659*523fa7a6SAndroid Build Coastguard Worker 660*523fa7a6SAndroid Build Coastguard Worker def forward(self, x,y): 661*523fa7a6SAndroid Build Coastguard Worker return control_flow.map(map_fn, x, y) 662*523fa7a6SAndroid Build Coastguard Worker 663*523fa7a6SAndroid Build Coastguard Worker Corresponding graph: def forward(self, arg0_1, arg1_1): 664*523fa7a6SAndroid Build Coastguard Worker submodule_0 = self.submodule_0 map_1 = torch.ops.higher_order.map_impl(submodule_0, arg0_1, arg1_1); 665*523fa7a6SAndroid Build Coastguard Worker submodule_0 = arg0_1 = arg1_1 = None return [map_1] 666*523fa7a6SAndroid Build Coastguard Worker 667*523fa7a6SAndroid Build Coastguard Worker submodule_0: def forward(self, arg0_1, arg1_1): 668*523fa7a6SAndroid Build Coastguard Worker add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return 669*523fa7a6SAndroid Build Coastguard Worker add_tensor 670*523fa7a6SAndroid Build Coastguard Worker 671*523fa7a6SAndroid Build Coastguard Worker Post the transformations done by emit_map this is what the submodule program looks like. 
def 672*523fa7a6SAndroid Build Coastguard Worker forward(self, arg0_1, arg1_1): 673*523fa7a6SAndroid Build Coastguard Worker sym_size = torch.ops.aten.sym_size(arg0_1) # Emitter creates a variable here to track 674*523fa7a6SAndroid Build Coastguard Worker iteration index select_copy_tensor = torch.ops.aten.select(arg0_1, 0, iteration_index) 675*523fa7a6SAndroid Build Coastguard Worker add_tensor = torch.ops.aten.add.Tensor(select_copy_tensor, arg1_1); arg0_1 = arg1_1 = 676*523fa7a6SAndroid Build Coastguard Worker None output_of_map = torch.ops.executorch.prim.et_copy_index(output_of_map, add_tensor, 677*523fa7a6SAndroid Build Coastguard Worker iteration_index) iteration_index = torch.ops.executorch.prim.add.int(iteration_index, 1, 678*523fa7a6SAndroid Build Coastguard Worker iteration_index) done_bool = torch.ops.executorch.prim.eq.int(iteration_index, sym_size, 679*523fa7a6SAndroid Build Coastguard Worker done_bool) # Emitter inserts a instruction here, if done_bool == False jump to 680*523fa7a6SAndroid Build Coastguard Worker selcect_copy op # if not continue. return add_tensor 681*523fa7a6SAndroid Build Coastguard Worker """ 682*523fa7a6SAndroid Build Coastguard Worker assert isinstance( 683*523fa7a6SAndroid Build Coastguard Worker subemitter_binding_output_values, (list, tuple) 684*523fa7a6SAndroid Build Coastguard Worker ), f"Expect a list for subemitter_binding_output_values for map. Got {subemitter_binding_output_values}." 685*523fa7a6SAndroid Build Coastguard Worker 686*523fa7a6SAndroid Build Coastguard Worker if len(subemitter_binding_output_values) != 1: 687*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 688*523fa7a6SAndroid Build Coastguard Worker f"Multiple outputs are not supported. Got {len(subemitter_binding_output_values)}." 
689*523fa7a6SAndroid Build Coastguard Worker ) 690*523fa7a6SAndroid Build Coastguard Worker f, mapped_args, inputs = args 691*523fa7a6SAndroid Build Coastguard Worker assert isinstance(mapped_args, (list, tuple)) 692*523fa7a6SAndroid Build Coastguard Worker num_mapped_args: int = len(mapped_args) 693*523fa7a6SAndroid Build Coastguard Worker if num_mapped_args != 1: 694*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 695*523fa7a6SAndroid Build Coastguard Worker f"Emitting map with more than one mapped args is not supported. Got {num_mapped_args}." 696*523fa7a6SAndroid Build Coastguard Worker ) 697*523fa7a6SAndroid Build Coastguard Worker x = mapped_args[0] 698*523fa7a6SAndroid Build Coastguard Worker 699*523fa7a6SAndroid Build Coastguard Worker assert isinstance(f, torch.fx.GraphModule) 700*523fa7a6SAndroid Build Coastguard Worker 701*523fa7a6SAndroid Build Coastguard Worker # Generate the EValue that we will use as our iterator index to keep track of which 702*523fa7a6SAndroid Build Coastguard Worker # iteration we are currently on. 703*523fa7a6SAndroid Build Coastguard Worker iter_idx = self._emit_evalue(EValue(Int(0))) 704*523fa7a6SAndroid Build Coastguard Worker # Generate the kernel call that will output the number of iterations we need to run for this 705*523fa7a6SAndroid Build Coastguard Worker # input tensor. 
706*523fa7a6SAndroid Build Coastguard Worker op_index, op = self._get_operator( 707*523fa7a6SAndroid Build Coastguard Worker name="aten::sym_size", 708*523fa7a6SAndroid Build Coastguard Worker overload="int", 709*523fa7a6SAndroid Build Coastguard Worker ) 710*523fa7a6SAndroid Build Coastguard Worker sym_size = self._emit_evalue(EValue(Int(0))) 711*523fa7a6SAndroid Build Coastguard Worker kernel = Instruction( 712*523fa7a6SAndroid Build Coastguard Worker KernelCall( 713*523fa7a6SAndroid Build Coastguard Worker op_index=op_index, 714*523fa7a6SAndroid Build Coastguard Worker args=[x.id, self._emit_evalue(EValue(Int(0))).id, sym_size.id], 715*523fa7a6SAndroid Build Coastguard Worker ) 716*523fa7a6SAndroid Build Coastguard Worker ) 717*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(kernel) 718*523fa7a6SAndroid Build Coastguard Worker 719*523fa7a6SAndroid Build Coastguard Worker # This kernel call will slice the input tensor along the index specified in iter_idx to 720*523fa7a6SAndroid Build Coastguard Worker # generate the input slice on which this iteration will be working on. 721*523fa7a6SAndroid Build Coastguard Worker op_index, op = self._get_operator( 722*523fa7a6SAndroid Build Coastguard Worker name="aten::select_copy", 723*523fa7a6SAndroid Build Coastguard Worker overload="int_out", 724*523fa7a6SAndroid Build Coastguard Worker ) 725*523fa7a6SAndroid Build Coastguard Worker # This select copy has to output to the tensor which is the input placeholder to the map 726*523fa7a6SAndroid Build Coastguard Worker # sub-graph. That placeholder isn't allocated an EValue id until the map emitter is run, so 727*523fa7a6SAndroid Build Coastguard Worker # we temporarily store -1 until the map emitter is run during which the placeholder will be 728*523fa7a6SAndroid Build Coastguard Worker # allocated an EValue id. After the map emitter is run we will retrieve that id and replace 729*523fa7a6SAndroid Build Coastguard Worker # the -1's. 
730*523fa7a6SAndroid Build Coastguard Worker kernel = Instruction( 731*523fa7a6SAndroid Build Coastguard Worker KernelCall( 732*523fa7a6SAndroid Build Coastguard Worker op_index=op_index, 733*523fa7a6SAndroid Build Coastguard Worker args=[ 734*523fa7a6SAndroid Build Coastguard Worker x.id, 735*523fa7a6SAndroid Build Coastguard Worker self._emit_evalue(EValue(Int(0))).id, 736*523fa7a6SAndroid Build Coastguard Worker iter_idx.id, 737*523fa7a6SAndroid Build Coastguard Worker -1, # input_tensor_value.id, 738*523fa7a6SAndroid Build Coastguard Worker -1, # input_tensor_value.id, 739*523fa7a6SAndroid Build Coastguard Worker ], 740*523fa7a6SAndroid Build Coastguard Worker ) 741*523fa7a6SAndroid Build Coastguard Worker ) 742*523fa7a6SAndroid Build Coastguard Worker # Store the index of this instruction as it will be where we will jump back to after the end 743*523fa7a6SAndroid Build Coastguard Worker # of an iteration. 744*523fa7a6SAndroid Build Coastguard Worker jump_to_instruction = self.instruction_start_offset + len( 745*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions 746*523fa7a6SAndroid Build Coastguard Worker ) 747*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(kernel) 748*523fa7a6SAndroid Build Coastguard Worker 749*523fa7a6SAndroid Build Coastguard Worker # Emit the map operator submodule. 750*523fa7a6SAndroid Build Coastguard Worker map_emitter = _Emitter( 751*523fa7a6SAndroid Build Coastguard Worker f, 752*523fa7a6SAndroid Build Coastguard Worker self.emitter_state, 753*523fa7a6SAndroid Build Coastguard Worker self.program_state, 754*523fa7a6SAndroid Build Coastguard Worker instruction_start_offset=self.instruction_start_offset 755*523fa7a6SAndroid Build Coastguard Worker + len(self.chain.instructions), 756*523fa7a6SAndroid Build Coastguard Worker # Only the first input is a placeholder, rest of the inputs are args to the map fn. 
757*523fa7a6SAndroid Build Coastguard Worker binding_input_values=[-1, *inputs], 758*523fa7a6SAndroid Build Coastguard Worker binding_output_values=subemitter_binding_output_values, 759*523fa7a6SAndroid Build Coastguard Worker ) 760*523fa7a6SAndroid Build Coastguard Worker map_emitter.run() 761*523fa7a6SAndroid Build Coastguard Worker 762*523fa7a6SAndroid Build Coastguard Worker # Merge all the instructions from the map submodule. 763*523fa7a6SAndroid Build Coastguard Worker self._merge_chain(map_emitter.chain) 764*523fa7a6SAndroid Build Coastguard Worker # Get rid of the return instruction emitted by the map subemitter. 765*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.pop() 766*523fa7a6SAndroid Build Coastguard Worker # At the end of each submodule emit we insert a move call that moves the output of the 767*523fa7a6SAndroid Build Coastguard Worker # submodule to a deterministic EValue, which is especially useful for if/else branches where 768*523fa7a6SAndroid Build Coastguard Worker # we want the output of either branch to be in the same EValue, but we don't need a move 769*523fa7a6SAndroid Build Coastguard Worker # here as our custom op executorch_prim::et_copy_index which is inserted later does that 770*523fa7a6SAndroid Build Coastguard Worker # for us. 771*523fa7a6SAndroid Build Coastguard Worker 772*523fa7a6SAndroid Build Coastguard Worker # Now that the map emitter has finished running retrieve the input placeholder EValue id and 773*523fa7a6SAndroid Build Coastguard Worker # update the select_copy kernel call to output to those id's. 
774*523fa7a6SAndroid Build Coastguard Worker kernel.instr_args.args[-1] = map_emitter.binding_input_values[0].id 775*523fa7a6SAndroid Build Coastguard Worker kernel.instr_args.args[-2] = kernel.instr_args.args[-1] 776*523fa7a6SAndroid Build Coastguard Worker 777*523fa7a6SAndroid Build Coastguard Worker self._internal_assert_emitter( 778*523fa7a6SAndroid Build Coastguard Worker len(map_emitter.concrete_output_ids) == 1, 779*523fa7a6SAndroid Build Coastguard Worker self.node, 780*523fa7a6SAndroid Build Coastguard Worker "Map should return only one element", 781*523fa7a6SAndroid Build Coastguard Worker ) 782*523fa7a6SAndroid Build Coastguard Worker 783*523fa7a6SAndroid Build Coastguard Worker # Here we call the custom op, specially added for the map operator. The output of this 784*523fa7a6SAndroid Build Coastguard Worker # iteration will be appended to the accumulator tensor that we are maintaining. This 785*523fa7a6SAndroid Build Coastguard Worker # accumulator tensor is the actual output of the map submodule. 
786*523fa7a6SAndroid Build Coastguard Worker op_index, op = self._get_operator( 787*523fa7a6SAndroid Build Coastguard Worker name="executorch_prim::et_copy_index", 788*523fa7a6SAndroid Build Coastguard Worker overload="tensor", 789*523fa7a6SAndroid Build Coastguard Worker ) 790*523fa7a6SAndroid Build Coastguard Worker kernel = Instruction( 791*523fa7a6SAndroid Build Coastguard Worker KernelCall( 792*523fa7a6SAndroid Build Coastguard Worker op_index, 793*523fa7a6SAndroid Build Coastguard Worker args=[ 794*523fa7a6SAndroid Build Coastguard Worker subemitter_binding_output_values[0].id, 795*523fa7a6SAndroid Build Coastguard Worker map_emitter.concrete_output_ids[0].id, 796*523fa7a6SAndroid Build Coastguard Worker iter_idx.id, 797*523fa7a6SAndroid Build Coastguard Worker ], 798*523fa7a6SAndroid Build Coastguard Worker ) 799*523fa7a6SAndroid Build Coastguard Worker ) 800*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(kernel) 801*523fa7a6SAndroid Build Coastguard Worker 802*523fa7a6SAndroid Build Coastguard Worker # Increment iter_idx to mark that we have completed an iteration. 
803*523fa7a6SAndroid Build Coastguard Worker op_index, op = self._get_operator( 804*523fa7a6SAndroid Build Coastguard Worker name="executorch_prim::add", 805*523fa7a6SAndroid Build Coastguard Worker overload="Scalar", 806*523fa7a6SAndroid Build Coastguard Worker ) 807*523fa7a6SAndroid Build Coastguard Worker kernel = Instruction( 808*523fa7a6SAndroid Build Coastguard Worker KernelCall( 809*523fa7a6SAndroid Build Coastguard Worker op_index, 810*523fa7a6SAndroid Build Coastguard Worker args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id], 811*523fa7a6SAndroid Build Coastguard Worker ) 812*523fa7a6SAndroid Build Coastguard Worker ) 813*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(kernel) 814*523fa7a6SAndroid Build Coastguard Worker 815*523fa7a6SAndroid Build Coastguard Worker jump_bool_value = self._emit_evalue(EValue(Bool(False))) 816*523fa7a6SAndroid Build Coastguard Worker 817*523fa7a6SAndroid Build Coastguard Worker # Generate the kernel call to check whether or not we have completed all the iterations. If 818*523fa7a6SAndroid Build Coastguard Worker # not jump back to the select_copy instruction that we generated at the beginning of this 819*523fa7a6SAndroid Build Coastguard Worker # section. 
820*523fa7a6SAndroid Build Coastguard Worker op_index, op = self._get_operator( 821*523fa7a6SAndroid Build Coastguard Worker name="executorch_prim::eq", 822*523fa7a6SAndroid Build Coastguard Worker overload="Scalar", 823*523fa7a6SAndroid Build Coastguard Worker ) 824*523fa7a6SAndroid Build Coastguard Worker kernel = Instruction( 825*523fa7a6SAndroid Build Coastguard Worker KernelCall( 826*523fa7a6SAndroid Build Coastguard Worker op_index, 827*523fa7a6SAndroid Build Coastguard Worker args=[iter_idx.id, sym_size.id, jump_bool_value.id], 828*523fa7a6SAndroid Build Coastguard Worker ) 829*523fa7a6SAndroid Build Coastguard Worker ) 830*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(kernel) 831*523fa7a6SAndroid Build Coastguard Worker 832*523fa7a6SAndroid Build Coastguard Worker jf_beginning_loop = Instruction( 833*523fa7a6SAndroid Build Coastguard Worker JumpFalseCall( 834*523fa7a6SAndroid Build Coastguard Worker cond_value_index=jump_bool_value.id, 835*523fa7a6SAndroid Build Coastguard Worker destination_instruction=jump_to_instruction, 836*523fa7a6SAndroid Build Coastguard Worker ) 837*523fa7a6SAndroid Build Coastguard Worker ) 838*523fa7a6SAndroid Build Coastguard Worker 839*523fa7a6SAndroid Build Coastguard Worker self.chain.instructions.append(jf_beginning_loop) 840*523fa7a6SAndroid Build Coastguard Worker 841*523fa7a6SAndroid Build Coastguard Worker # Reset iter_idx in case we plan to run the model again. 
        op_index, op = self._get_operator(
            name="executorch_prim::sub",
            overload="Scalar",
        )
        kernel = Instruction(
            KernelCall(
                op_index,
                args=[iter_idx.id, sym_size.id, iter_idx.id],
            )
        )
        self.chain.instructions.append(kernel)

        return subemitter_binding_output_values

    def _emit_control_flow(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Wraps common logic for emitting all control flow operations.

        See the more specific emission functions for more details on how cond or map get emitted.
        """
        subemitter_binding_output_values = pytree.tree_map(
            lambda spec: self._emit_spec(spec),
            self.node.meta["spec"],
        )

        if target is torch.ops.higher_order.cond:
            return self._emit_cond(args, subemitter_binding_output_values)
        elif target is torch.ops.higher_order.map_impl:
            return self._emit_map(args, subemitter_binding_output_values)
        else:
            raise InternalError(
                self._emit_node_specific_error(
                    self.node, f"Unsupported control flow operator: {target}"
                )
            )

    def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
        assert len(args) == 2

        self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
        size_arg = self._emit_argument(args[1], torch.ListType.ofInts())
        out_arg = self._emit_argument(
            self._emit_spec(self.node.meta["spec"]), torch.TensorType  # pyre-ignore[6]
        )

        op_idx, op = self._get_operator(
            name="executorch_prim::et_view",
            overload="default",
        )
        kernel = Instruction(
            KernelCall(
                op_idx,
                args=[
                    self_arg.id,
                    size_arg.id,
                    out_arg.id,
                ],
            )
        )
        self.chain.instructions.append(kernel)
        return out_arg

    def _add_debug_handle(
        self,
        emitter_id: int,
        target: _Target,
        # pyre-ignore[11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "Optional[LoweredBackendModule]" = None,  # noqa: F821
    ) -> None:
        """Updates the debug handle information for the current node.

        If the current node is a delegate, we aggregate the debug handles of the subgraph and
        store them in the map.
        If the current node is any other type, we store the original information in the debug
        handle map and replace it with the executorch instruction index corresponding to this
        node.
        """
        # If it's a delegate call, collect the list of debug handles that are consumed by this
        # delegate call and store it in the debug handle map.
        if target == executorch_call_delegate:
            debug_handle_list = []
            # Use the lowered_module to fetch the original graph and its debug
            # handles.
            for node in lowered_module.original_module.graph.nodes:
                if (
                    node.op == "call_function"
                    and node.meta.get("debug_handle") is not None
                ):
                    debug_handle_list.append(node.meta.get("debug_handle"))
            self.debug_handle_map[emitter_id] = debug_handle_list
            # The debug handle for this node is the emitter_id, which is essentially the index
            # of the instruction in the chain.
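# The aggregation above amounts to collecting the `debug_handle` metadata of every
# `call_function` node in the delegate's original graph. A minimal sketch of that filtering,
# using plain dicts as illustrative stand-ins for `fx.Node` objects (not the real fx API):

```python
# Each stand-in node has an op kind and a metadata dict, mirroring
# node.op and node.meta in torch.fx.
nodes = [
    {"op": "placeholder", "meta": {}},
    {"op": "call_function", "meta": {"debug_handle": 1}},
    {"op": "call_function", "meta": {}},  # no handle recorded; skipped
    {"op": "call_function", "meta": {"debug_handle": 2}},
    {"op": "output", "meta": {}},
]

debug_handle_list = [
    node["meta"]["debug_handle"]
    for node in nodes
    if node["op"] == "call_function" and node["meta"].get("debug_handle") is not None
]
print(debug_handle_list)  # [1, 2]
```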
            self.node.meta["debug_handle"] = emitter_id
            return

        if self.node.meta.get("debug_handle") is not None:
            # Store the original debug handle in the debug handle map.
            self.debug_handle_map[emitter_id] = self.node.meta.get("debug_handle")
            # Replace the debug handle in the metadata of the node with the emitter id, which
            # represents the instruction index in the chain. We do this because in the runtime
            # the instruction index is what is logged during perf/debug data logging, and hence
            # we want to store this in the node so that we can map the data logged by the
            # runtime back to the node.
            self.node.meta["debug_handle"] = emitter_id

    def _add_delegate_map(
        self,
        lowered_module: "LoweredBackendModule",  # noqa
        delegate_instruction_id: int,
    ) -> None:
        """
        Store the delegate map from this lowered module into the dictionary of delegate maps. It
        will later be used for various debugging purposes such as linking back to the original
        source code, module hierarchy, etc.
        """
        delegate_map = {}
        if hasattr(lowered_module, "meta"):
            delegate_map = lowered_module.meta.get("debug_handle_map", {})

        self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
            "name": lowered_module.backend_id,
            "delegate_map": delegate_map,
        }

    def _emit_argument(
        self, arg: _Argument, arg_type: Optional[_SchemaType]
    ) -> _AbstractValue:
        """Emit an argument to an operator or delegate if it has not already been emitted;
        otherwise return the previously emitted location."""
        if isinstance(arg, _AbstractValue):
            return arg
        return self._emit_evalue(self._constant_to_evalue(arg, arg_type))

    def _get_sym_ret(
        self,
        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
    ) -> Optional[_AbstractValue]:
        """
        Returns the emitted placeholder value for a sym value.
        """
        ret = None
        if isinstance(val, torch.SymInt):
            ret = self._emit_evalue(EValue(Int(0)))
        elif isinstance(val, torch.BoolType):
            ret = self._emit_evalue(EValue(Bool(False)))
        elif isinstance(val, torch.FloatType):
            ret = self._emit_evalue(EValue(Double(0)))
        return ret

    def _get_sym_and_fake_tensor_ret(
        self,
        val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]],
        spec: TensorSpec,
    ) -> Union[List[_AbstractValue], _AbstractValue, Tuple[_AbstractValue, ...]]:
        # Try to get the ret if it's a sym value.
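# For reference, `_get_sym_ret` above just maps a symbolic value's type to a typed placeholder
# EValue that the runtime later overwrites. A self-contained sketch of that dispatch, using
# Python builtins as illustrative stand-ins for the torch types and tagged tuples for EValues:

```python
# Stand-ins: bool/int/float play the roles of torch.BoolType, torch.SymInt,
# and torch.FloatType; anything else plays the role of a FakeTensor.
def placeholder_for(val):
    """Return a typed default for a symbolic value, or None for a tensor."""
    if isinstance(val, bool):  # checked first: bool is a subclass of int
        return ("Bool", False)
    if isinstance(val, int):
        return ("Int", 0)
    if isinstance(val, float):
        return ("Double", 0.0)
    return None  # not symbolic; the caller falls back to _emit_spec

print(placeholder_for(7), placeholder_for(True), placeholder_for(object()))
```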
        ret = self._get_sym_ret(val)
        # If the ret is None, it means that the val is not a sym value but a regular tensor.
        if ret is None:
            ret = self._emit_spec(spec)
        assert ret is not None, "Can't have a None ret"
        return ret

    def _emit_delegate(
        self,
        lowered_module: "LoweredBackendModule",  # noqa
        args: Tuple[_Argument, ...],
        kwargs: Dict[str, _Argument],
    ) -> _EmitterValue:
        """Emit the delegate's inputs and outputs as specified by the schema, then emit the
        delegate's blob."""
        processed_bytes = lowered_module.processed_bytes

        delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
        delegate_ret = None

        if isinstance(self.node.meta["spec"], list):
            delegate_ret = []
            for index, _ in enumerate(self.node.meta["val"]):
                ret = self._get_sym_and_fake_tensor_ret(
                    self.node.meta["val"][index], self.node.meta["spec"][index]
                )
                delegate_ret.append(ret)
        elif isinstance(self.node.meta["spec"], tuple):
            if isinstance(self.node.meta["val"], FakeTensor):
                # There is a case where node.meta["spec"] is (TensorSpec,) while
                # node.meta["val"] is a FakeTensor.
                ret = self._get_sym_and_fake_tensor_ret(
                    self.node.meta["val"], self.node.meta["spec"][0]
                )
                delegate_ret = (ret,)
            else:
                delegate_ret = []
                for index, _ in enumerate(self.node.meta["val"]):
                    ret = self._get_sym_and_fake_tensor_ret(
                        self.node.meta["val"][index], self.node.meta["spec"][index]
                    )
                    delegate_ret.append(ret)
                delegate_ret = tuple(delegate_ret)
        elif isinstance(self.node.meta["spec"], TensorSpec):
            ret = self._get_sym_and_fake_tensor_ret(
                self.node.meta["val"], self.node.meta["spec"]
            )
            delegate_ret = ret
        else:
            raise NotImplementedError(
                f"self.node.meta['spec'] {type(self.node.meta['spec'])} is not supported"
            )
        assert delegate_ret is not None, "Can't have a None delegate_ret"
        if delegate_index is None:
            # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
            # present.
            data_index: int = len(self.program_state.backend_delegate_data)
            self.program_state.backend_delegate_data.append(
                BackendDelegateInlineData(data=processed_bytes)
            )

            backend_delegate = BackendDelegate(
                id=lowered_module.backend_id,
                processed=BackendDelegateDataReference(
                    location=DataLocation.INLINE, index=data_index
                ),
                compile_specs=lowered_module.compile_specs,
            )
            delegate_index = len(self.emitter_state.delegate_cache)
            self.emitter_state.delegates.append(backend_delegate)
            self.emitter_state.delegate_cache[processed_bytes] = delegate_index

        # TODO(angelayi): Will need to emit the kwargs too, in the correct order according to
        # the function's spec and with default arguments. This requires us to store the
        # function's spec in to_backend().
        delegate_args = [
            self._emit_argument(arg, None).id
            for arg in typing.cast(List[_Argument], args)
        ]

        for elem in pytree.tree_flatten(delegate_ret)[0]:
            delegate_args.append(elem.id)

        self.chain.instructions.append(
            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
        )

        return delegate_ret

    def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
        """Given a fully qualified name, looks up the operator in the ExecuTorch Program, or
        adds it if it is not already present."""
        key = (name, overload)
        op_index = self.emitter_state.operator_cache.get(key)
        if op_index is not None:
            return op_index, self.emitter_state.operators[op_index]

        op_index, operator = len(self.emitter_state.operators), Operator(
            name=name, overload=overload
        )
        self.emitter_state.operators.append(operator)
        self.emitter_state.operator_cache[key] = op_index
        return op_index, operator

    def _emit_operator(  # noqa: C901
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Emits an operator (aten or custom); directly translates to a call_kernel instruction."""
        assert isinstance(
            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
        ), f"target is {target}"

        # Grab the name.
        op_name = target._overloadpacket._qualified_op_name
        op_overload = ""
        if target._overloadname != "default":
            op_overload = target._overloadname

        def _get_empty_tensor_evalue() -> EValue:
            """Constructs an EValue for an empty tensor."""
            return EValue(
                Tensor(
                    scalar_type=ScalarType.BYTE,
                    # The runtime currently only supports tensors with offset 0.
                    storage_offset=0,
                    sizes=[0],
                    dim_order=[],
                    requires_grad=False,
                    layout=0,
                    data_buffer_idx=0,
                    allocation_info=None,
                    shape_dynamism=TensorShapeDynamism.STATIC,
                )
            )

        op_index, operator = self._get_operator(name=op_name, overload=op_overload)

        # Emit the args and kwargs in the order given by the function schema.
        kernel_args = []
        out_args = []
        for i, schema_arg in enumerate(target._schema.arguments):
            if schema_arg.name in kwargs:
                kernel_arg = kwargs[schema_arg.name]
            elif not schema_arg.kwarg_only and i < len(args):
                kernel_arg = args[i]
            else:
                # Emit default values.
                kernel_arg = schema_arg.default_value

            if kernel_arg is None and isinstance(schema_arg.type, torch.TensorType):
                kernel_arg = self._emit_evalue(_get_empty_tensor_evalue())

            kernel_args.append(self._emit_argument(kernel_arg, schema_arg.type).id)

            if schema_arg.is_out:
                out_args.append((schema_arg.name, kernel_arg))

        if is_out_variant(op_name, op_overload):
            ret = [val for _, val in out_args]
            ret = ret[0] if len(ret) == 1 else ret
        elif is_sym_op(target):
            assert (
                len(target._schema.returns) == 1
            ), "Only returning a single Sym from symbolic ops is supported currently."
            assert type(target._schema.returns[0].type) in (
                torch.IntType,
                torch.FloatType,
                torch.BoolType,
                torch.NumberType,
            ), f"Only symbolic ops that return an Int, Bool, or Float are supported currently; got {type(target._schema.returns[0].type)}."
            ret = self._get_sym_ret(target._schema.returns[0])
            if ret is None:  # type(target._schema.returns[0].type) == torch.NumberType
                # Can't definitively say what type this is; the runtime operator just overrides
                # the EValue completely, though, so we can serialize anything as a placeholder.
                ret = self._emit_evalue(EValue(Int(0)))
        else:
            ret = self._emit_spec(self.node.meta["spec"])

        out_args = (
            self._emit_evalue(
                EValue(TensorList([cast(_AbstractValue, val).id for val in ret]))
            )
            if isinstance(ret, list)
            else ret
        )

        for elem in pytree.tree_flatten(out_args)[0]:
            kernel_args.append(cast(_AbstractValue, elem).id)

        self.chain.instructions.append(
            Instruction(KernelCall(op_index=op_index, args=kernel_args))
        )
        self._add_debug_handle(len(self.chain.instructions) - 1, target)

        # Get the stacktrace if it exists for each instruction.
        if self.emitter_state.emit_stacktrace:
            stack_trace = self.node.meta["stack_trace"]
            chain_stacktrace = self.chain.stacktrace or []

            chain_stacktrace.append(_stacktrace_to_framelist(stack_trace))
            self._internal_assert_emitter(
                len(chain_stacktrace) == len(self.chain.instructions),
                self.node,
                f"Each instruction should have a corresponding stacktrace; received "
                f"{len(self.chain.instructions)} instructions and "
                f"{len(chain_stacktrace)} stacktraces",
            )
            self.chain.stacktrace = chain_stacktrace

        return cast(_EmitterValue, ret)

    def _emit_free(self, spec: TensorSpec) -> _AbstractValue:
        """Emits a FreeCall instruction to release a given unbounded tensor's memory."""
        self.chain.instructions.append(
            Instruction(FreeCall(value_index=self.emitter_state.spec2id(spec)))
        )
        # The value is not used, but the caller expects an AbstractValue returned.
        return _AbstractValue(None, None)  # pyre-ignore

    def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
        """
        Given a mapping of function names to return values, emit simple execution
        plans that just return these constant values.

        Precondition: All the values are primitives (bool, float, int, str, enum)
        or structures (list, dict) of them.
        """
        plans = []
        # Flatten any structures.
        for method, vals in prim_getters.items():
            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
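# The flattening step below walks a nested structure of lists/dicts and returns its leaves in a
# deterministic order plus a spec describing the shape. A rough pure-Python approximation of
# that behavior (the real ex_pytree also handles more container types and serializes the spec):

```python
def flatten(obj):
    """Depth-first flatten of nested lists/dicts into a leaf list plus a
    structure description; a simplified stand-in for ex_pytree.tree_flatten."""
    if isinstance(obj, list):
        leaves, specs = [], []
        for item in obj:
            sub_leaves, sub_spec = flatten(item)
            leaves.extend(sub_leaves)
            specs.append(sub_spec)
        return leaves, ("list", specs)
    if isinstance(obj, dict):
        leaves, specs = [], []
        for key in obj:  # insertion order, like pytree's dict handling
            sub_leaves, sub_spec = flatten(obj[key])
            leaves.extend(sub_leaves)
            specs.append((key, sub_spec))
        return leaves, ("dict", specs)
    return [obj], "leaf"

leaves, spec = flatten({"a": [1, 2], "b": True})
print(leaves)  # [1, 2, True]
```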
1226*523fa7a6SAndroid Build Coastguard Worker flattened_output, spec = ex_pytree.tree_flatten(vals) 1227*523fa7a6SAndroid Build Coastguard Worker spec = spec.to_str() 1228*523fa7a6SAndroid Build Coastguard Worker chain = Chain( 1229*523fa7a6SAndroid Build Coastguard Worker inputs=[], 1230*523fa7a6SAndroid Build Coastguard Worker outputs=[], 1231*523fa7a6SAndroid Build Coastguard Worker instructions=[], 1232*523fa7a6SAndroid Build Coastguard Worker stacktrace=None, 1233*523fa7a6SAndroid Build Coastguard Worker ) 1234*523fa7a6SAndroid Build Coastguard Worker 1235*523fa7a6SAndroid Build Coastguard Worker # switch on type of prim 1236*523fa7a6SAndroid Build Coastguard Worker values = [] 1237*523fa7a6SAndroid Build Coastguard Worker for val in flattened_output: 1238*523fa7a6SAndroid Build Coastguard Worker if isinstance(val, float): 1239*523fa7a6SAndroid Build Coastguard Worker values.append(EValue(Double(val))) 1240*523fa7a6SAndroid Build Coastguard Worker 1241*523fa7a6SAndroid Build Coastguard Worker elif isinstance(val, bool): 1242*523fa7a6SAndroid Build Coastguard Worker values.append(EValue(Bool(val))) 1243*523fa7a6SAndroid Build Coastguard Worker 1244*523fa7a6SAndroid Build Coastguard Worker elif isinstance(val, int): 1245*523fa7a6SAndroid Build Coastguard Worker values.append(EValue(Int(val))) 1246*523fa7a6SAndroid Build Coastguard Worker 1247*523fa7a6SAndroid Build Coastguard Worker elif isinstance(val, str): 1248*523fa7a6SAndroid Build Coastguard Worker values.append(EValue(String(val))) 1249*523fa7a6SAndroid Build Coastguard Worker 1250*523fa7a6SAndroid Build Coastguard Worker elif isinstance(val, torch.dtype): 1251*523fa7a6SAndroid Build Coastguard Worker values.append(EValue(Int(scalar_type_enum(val)))) 1252*523fa7a6SAndroid Build Coastguard Worker 1253*523fa7a6SAndroid Build Coastguard Worker elif isinstance(val, torch.layout): 1254*523fa7a6SAndroid Build Coastguard Worker values.append(EValue(Int(layout_enum(val)))) 1255*523fa7a6SAndroid Build Coastguard 
Worker 1256*523fa7a6SAndroid Build Coastguard Worker elif isinstance(val, torch.Tensor): 1257*523fa7a6SAndroid Build Coastguard Worker values.append( 1258*523fa7a6SAndroid Build Coastguard Worker self._tensor_spec_to_evalue( 1259*523fa7a6SAndroid Build Coastguard Worker TensorSpec.from_tensor(val, const=True) 1260*523fa7a6SAndroid Build Coastguard Worker ) 1261*523fa7a6SAndroid Build Coastguard Worker ) 1262*523fa7a6SAndroid Build Coastguard Worker 1263*523fa7a6SAndroid Build Coastguard Worker else: 1264*523fa7a6SAndroid Build Coastguard Worker raise ExportError( 1265*523fa7a6SAndroid Build Coastguard Worker ExportErrorType.NOT_SUPPORTED, 1266*523fa7a6SAndroid Build Coastguard Worker f"Error emitting {method} which returns a value of type {type(val)}. which is not a supported primitive", 1267*523fa7a6SAndroid Build Coastguard Worker ) 1268*523fa7a6SAndroid Build Coastguard Worker 1269*523fa7a6SAndroid Build Coastguard Worker # add to plans 1270*523fa7a6SAndroid Build Coastguard Worker plans.append( 1271*523fa7a6SAndroid Build Coastguard Worker ExecutionPlan( 1272*523fa7a6SAndroid Build Coastguard Worker name=method, 1273*523fa7a6SAndroid Build Coastguard Worker values=values, 1274*523fa7a6SAndroid Build Coastguard Worker inputs=[], 1275*523fa7a6SAndroid Build Coastguard Worker outputs=list(range(0, len(values))), 1276*523fa7a6SAndroid Build Coastguard Worker chains=[chain], 1277*523fa7a6SAndroid Build Coastguard Worker operators=[], 1278*523fa7a6SAndroid Build Coastguard Worker delegates=[], 1279*523fa7a6SAndroid Build Coastguard Worker non_const_buffer_sizes=[0], 1280*523fa7a6SAndroid Build Coastguard Worker container_meta_type=ContainerMetadata("", spec), 1281*523fa7a6SAndroid Build Coastguard Worker ) 1282*523fa7a6SAndroid Build Coastguard Worker ) 1283*523fa7a6SAndroid Build Coastguard Worker return plans 1284*523fa7a6SAndroid Build Coastguard Worker 1285*523fa7a6SAndroid Build Coastguard Worker def fetch_attr(self, target: _Target) -> _AbstractValue: 
        """Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
        attr = super().fetch_attr(target)  # pyre-fixme[6]

        if isinstance(attr, torch.Tensor):
            return self._emit_evalue(
                self._tensor_spec_to_evalue(TensorSpec.from_tensor(attr, const=True))
            )

        elif isinstance(attr, torch._C.ScriptObject):
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                f"Custom class {attr} is not supported in EXIR",
            )

        else:
            return attr

    def call_module(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Unsupported in execution IR, so unhandled by the emitter."""
        raise InternalError(
            self._emit_node_specific_error(self.node, "call_module is not supported")
        )

    def call_method(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Unsupported in execution IR, so unhandled by the emitter."""
        raise InternalError(
            self._emit_node_specific_error(self.node, "call_method is not supported")
        )

    def placeholder(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _AbstractValue:
        """Performs actions for the placeholder node of a graph module.

        The placeholder node of the top level entry point is handled by TopLevelEmitter. This
        function only executes on control flow subgraphs. Takes the inputs of the subgraph that had
        not previously been emitted and emits them.
        """
        # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
        value = self.binding_input_values[self.placeholder_count]
        # This indicates that the placeholder wasn't allocated an EValue id before this sub-emitter
        # was run, so we generate one now.
        if value == -1:
            value = self._emit_evalue(
                self._tensor_spec_to_evalue(self.node.meta["spec"])
            )
            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
            self.binding_input_values[self.placeholder_count] = value
        self.placeholder_count += 1
        return value

    def output(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Performs actions for the output node of a graph module.

        The output node of the top level entry point is handled by TopLevelEmitter. This function
        only executes on control flow subgraphs. Takes the outputs of the subgraph (if any) and
        inserts instructions to move them to the common output location between control flow
        branches.
        """
        self.concrete_output_ids = list(pytree.tree_flatten(args[0])[0])
        binding_output_values = self.binding_output_values
        if binding_output_values is not None:
            binding_output_list, _ = pytree.tree_flatten(binding_output_values)

            self._internal_assert_emitter(
                len(binding_output_list) == len(self.concrete_output_ids),
                self.node,
                "The number of binding output values should match the args to output",
            )

            for move_from, move_to in zip(
                self.concrete_output_ids, binding_output_list
            ):
                if move_from != move_to:
                    instruction = Instruction(
                        MoveCall(move_from=move_from.id, move_to=move_to.id)
                    )
                    self.chain.instructions.append(instruction)

    def call_function(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _EmitterValue:
        """Performs actions for the call_function node of a graph module.

        Dispatches based on 'target' and emits the corresponding function. 'call_function' is a
        powerful node that covers many operations, ranging from control flow to memory management
        to delegate and operator calls.
        """

        # Delegate and operator calls are the only functions that should have a debug handle
        # associated with them. All the others, such as memory.alloc and getitem, should be
        # ignored. Default to none and let delegates and ops override.
        if target == operator.getitem:
            assert len(args) == 2
            head = typing.cast(Mapping[int, _EmitterValue], args[0])
            index = typing.cast(int, args[1])
            return head[index]

        elif target == memory.alloc:
            assert len(args) == 1
            return self._emit_spec(self.node.meta["spec"])

        elif target == memory.view:
            return self._emit_view(args)

        elif target == memory.free:
            assert len(args) == 1
            # pyre-ignore
            return self._emit_free(args[0])

        elif target is torch.ops.higher_order.cond:
            return self._emit_control_flow(target, args, kwargs)

        elif target is torch.ops.higher_order.map_impl:
            return self._emit_control_flow(target, args, kwargs)

        elif target == executorch_call_delegate:
            lowered_module = args[0]
            assert is_lowered_module(lowered_module)
            v = self._emit_delegate(lowered_module, args[1:], kwargs)
            delegate_instruction_id = len(self.chain.instructions) - 1
            self._add_debug_handle(delegate_instruction_id, target, lowered_module)
            self._add_delegate_map(lowered_module, delegate_instruction_id)
            return v

        elif isinstance(
            target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
        ):
            return self._emit_operator(target, args, kwargs)

        else:
            raise InternalError(
                self._emit_node_specific_error(
                    self.node, f"invalid target for call_function {target}"
                )
            )

    def run(  # pyre-fixme[14]
        self,
        *args: _Argument,
        initial_env: Optional[Dict[torch.fx.Node, _Argument]] = None,
    ) -> None:
        """Traverses all nodes in the graph module and emits each one appropriately."""
        super().run(*args, initial_env, enable_io_processing=False)

    def run_node(self, n: torch.fx.Node) -> None:
        """Executes and emits the specified node.

        For more context on what a node is and what execution means see
        https://pytorch.org/docs/stable/fx.html#torch.fx.Node
        """
        self.node = n
        try:
            ret = super().run_node(n)
        except Exception as e:
            if isinstance(e, (InternalError, ExportError)):
                raise e
            else:
                raise InternalError(
                    self._emit_node_specific_error(self.node, str(e))
                ) from e
        return ret


class _TopLevelEmitter(_Emitter):
    """An emitter that manages the root level operations within a graph module.

    Exists as a separate class so that 'Emitter' can handle the special behavior of 'placeholder'
    and 'output' nodes in control flow submodules.
    """

    def __init__(
        self,
        name: str,
        exported_program: ExportedProgram,
        graph_module: torch.fx.GraphModule,
        program_state: _ProgramState,
        emitter_state: _EmitterState,
    ) -> None:
        super().__init__(graph_module, emitter_state, program_state)
        self.name = name
        self.exported_program = exported_program

        self.inputs: List[int] = []
        self.outputs: List[int] = []
        self.given_mutable_buffer_warning = False

        def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
            if spec is None:
                return ""
            assert isinstance(spec, pytree.TreeSpec), type(spec)
            dummy_leaves = [0] * spec.num_leaves
            tree = torch.utils._pytree.tree_unflatten(dummy_leaves, spec)
            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
            _, tree = ex_pytree.tree_flatten(tree)
            return tree.to_str()

        inp_container_str = create_container_str(exported_program.call_spec.in_spec)
        out_container_str = create_container_str(exported_program.call_spec.out_spec)

        self.container_meta_type = ContainerMetadata(
            inp_container_str, out_container_str
        )

    def _find_fqn_for_placeholder(
        self, target: _Target, spec: Any  # pyre-ignore[2]
    ) -> Tuple[Optional[str], bool]:
        # Find the fully qualified name
        fqn = None
        is_mutable_buffer = False
        if target in self.exported_program.graph_signature.inputs_to_parameters:
            fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

        elif target in self.exported_program.graph_signature.inputs_to_buffers:
            fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

            # if the buffer is mutated then record that
            if fqn in self.exported_program.graph_signature.buffers_to_mutate.values():
                is_mutable_buffer = True
                if not self.given_mutable_buffer_warning:
                    warnings.warn(
                        "Mutation on a buffer in the model is detected. ExecuTorch assumes "
                        "buffers that are mutated in the graph have a meaningless initial state, "
                        "only the shape and dtype will be serialized.",
                        UserWarning,
                        stacklevel=1,
                    )
                    self.given_mutable_buffer_warning = True

        elif (
            target
            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
        ):
            fqn = (
                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                    target
                ]
            )
        return fqn, is_mutable_buffer
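The precedence in `_find_fqn_for_placeholder` (parameters first, then buffers with a mutability check, then lifted tensor constants) can be sketched in isolation. The sketch below is a minimal standalone analogue, not part of this module: it replaces the `ExportedProgram` graph signature with plain dicts, and the helper name `find_fqn` and its parameters are illustrative.

```python
from typing import Dict, Optional, Tuple


def find_fqn(
    target: str,
    inputs_to_parameters: Dict[str, str],
    inputs_to_buffers: Dict[str, str],
    buffers_to_mutate: Dict[str, str],
    inputs_to_lifted_tensor_constants: Dict[str, str],
) -> Tuple[Optional[str], bool]:
    """Resolve a placeholder target to (fully qualified name, is_mutable_buffer)."""
    if target in inputs_to_parameters:
        return inputs_to_parameters[target], False
    if target in inputs_to_buffers:
        fqn = inputs_to_buffers[target]
        # A buffer is mutable iff some graph output writes back to its fqn.
        return fqn, fqn in buffers_to_mutate.values()
    if target in inputs_to_lifted_tensor_constants:
        return inputs_to_lifted_tensor_constants[target], False
    # Not lifted at all: treated as a plain user input.
    return None, False
```

Under these assumptions, a buffer placeholder whose fqn appears among the `buffers_to_mutate` values resolves as mutable, which is what later makes `placeholder` mark its spec non-constant.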

    def placeholder(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> _AbstractValue:
        """Emits the value within the placeholder node.

        For more information on placeholder nodes see
        https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
        """
        spec = self.node.meta["spec"]
        is_user_input = True

        if isinstance(target, str) and isinstance(spec, TensorSpec):
            fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

            # From the fqn find the corresponding tensor
            real_tensor = None
            if fqn in self.exported_program.state_dict:
                real_tensor = self.exported_program.state_dict[fqn]
                is_user_input = False

            elif fqn in self.exported_program.constants:
                real_tensor = self.exported_program.constants[fqn]
                is_user_input = False

            elif fqn is not None:
                buffers = self.exported_program.named_buffers()
                buf = next((x[1] for x in buffers if x[0] == fqn), None)
                if buf is not None:
                    real_tensor = buf
                    is_user_input = False
                else:
                    raise InternalError(
                        self._emit_node_specific_error(
                            self.node,
                            f"Could not find buffer with fqn {fqn} in state_dict or named_buffers",
                        )
                    )

            # Assign the storage of the placeholder spec to the storage of the real tensor if there is one.
            if real_tensor is not None:
                # For non-contiguous tensors, convert to a contiguous one.
                real_tensor = real_tensor.contiguous()
                # Weights cannot be views during emission or serialization.
                if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
                    real_tensor = real_tensor.clone()

                spec.storage = real_tensor.untyped_storage()

            # User inputs and mutable buffers are not constants; other buffers or parameters are.
            spec.const = not (is_user_input or is_mutable_buffer)

        evalue = (
            self._tensor_spec_to_evalue(spec)
            if isinstance(spec, TensorSpec)
            else self._constant_to_evalue(spec, None)
        )
        value = self._emit_evalue(evalue)

        # Only user inputs should remain as inputs.
        if is_user_input:
            self.inputs.append(value.id)

        return value

    def output(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> None:
        """Records the ExecutionPlan's outputs based on the output node in the graph."""
        if isinstance(args[0], dict):
            args_tuple, _ = pytree.tree_flatten(args[0])
        else:
            args_tuple = typing.cast(Tuple[_AbstractValue, ...], args[0])
        if isinstance(args_tuple, _AbstractValue):
            self.outputs.append(args_tuple.id)
        else:
            for arg in args_tuple:
                if isinstance(arg, (int, float, bool, type(None))):
                    arg = self._emit_evalue(self._constant_to_evalue(arg, None))
                elif isinstance(arg, str):
                    # TODO(jackkhuu): T181599879 Add support for string outputs IFF compiler supports
                    raise InternalError(
                        self._emit_node_specific_error(
                            self.node,
                            f"Returning {arg} is not yet supported in the emitter.",
                        )
                    )
                else:
                    # Every other output should already have its value emitted.
                    # They should only be abstract IDs at this point.
                    assert isinstance(arg, _AbstractValue)
                    self.outputs.append(arg.id)

    def plan(self) -> ExecutionPlan:
        """Returns the execution plan emitted from this entry point."""
        return ExecutionPlan(
            name=self.name,
            values=self.emitter_state.values,
            inputs=self.inputs,
            outputs=self.outputs,
            chains=[self.chain],
            operators=self.emitter_state.operators,
            delegates=self.emitter_state.delegates,
            # non_const_buffer_sizes is set by the memory_planning_pass. In case the field is
            # missing, e.g. in unit tests that do not enable memory planning, assume an
            # empty list.
            non_const_buffer_sizes=typing.cast(
                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB...
                List[int], self.module.meta["non_const_buffer_sizes"]
            ),
            container_meta_type=self.container_meta_type,
        )
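As a closing illustration, the move coalescing done in `_Emitter.output` for control-flow subgraphs — pairing each concrete output EValue id with its binding output id and emitting a move only when they differ — can be modeled with plain ints. The `Move` dataclass and `plan_moves` helper below are a hypothetical sketch, not part of the emitter.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Move:
    move_from: int
    move_to: int


def plan_moves(concrete_ids: List[int], binding_ids: List[int]) -> List[Move]:
    """Emit one move per output slot, skipping values already in place."""
    assert len(concrete_ids) == len(binding_ids), (
        "The number of binding output values should match the args to output"
    )
    return [
        Move(move_from=src, move_to=dst)
        for src, dst in zip(concrete_ids, binding_ids)
        if src != dst
    ]
```

For example, outputs `[3, 4, 5]` binding to `[3, 9, 5]` need a single move (4 to 9); when both branches of a `cond` already write to the shared slots, no instructions are appended at all.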