# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""Python dataclasses mirroring the program schema (see program.fbs).

Several definitions defer to program.fbs for the meaning of their fields.
Union-typed fields (``EValue.val``, ``Instruction.instr_args``) are annotated
with *string* type names so that DataclassEncoder can see them; keep those
annotations as strings and keep field order/annotations as they are.
"""

from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Union

from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.scalar_type import ScalarType


@dataclass
class AllocationDetails:
    """Allocation info for a Tensor (see ``Tensor.allocation_info``).

    The offset is stored as two 32-bit halves; ``memory_offset`` reassembles
    them into a single integer.
    """

    memory_id: int
    # Low 32 bits
    memory_offset_low: int
    # High 32 bits (typically zero)
    memory_offset_high: int

    @property
    def memory_offset(self) -> int:
        """Return the full offset, combining the low and high 32-bit halves.

        The halves are OR-ed together, so this assumes ``memory_offset_low``
        fits in 32 bits.
        """
        return self.memory_offset_low | (self.memory_offset_high << 32)


@dataclass
class OptionalTensorList:
    """EValue payload holding a list of optional tensors.

    NOTE(review): ``items`` are presumably indices into the execution plan's
    ``values`` list (with a sentinel for "no tensor") -- confirm against
    program.fbs.
    """

    items: List[int]


class TensorShapeDynamism(IntEnum):
    """
    Check program.fbs for explanations of this enum.
    """

    STATIC = 0
    DYNAMIC_BOUND = 1
    DYNAMIC_UNBOUND = 2


@dataclass
class ExtraTensorInfo:
    """
    Check program.fbs for explanations of these fields.
    """

    mutable_data_segments_idx: Optional[int] = None
    fully_qualified_name: Optional[str] = None


@dataclass
class Tensor:
    """EValue payload describing a tensor in an execution plan's value table."""

    scalar_type: ScalarType
    storage_offset: int
    sizes: List[int]
    dim_order: List[bytes]
    requires_grad: bool
    layout: int
    data_buffer_idx: int
    allocation_info: Optional[AllocationDetails]

    # check program.fbs for explanations.
    shape_dynamism: TensorShapeDynamism
    extra_tensor_info: Optional[ExtraTensorInfo] = None


@dataclass
class Null:
    """EValue payload carrying no value."""

    pass


@dataclass
class Int:
    """EValue payload holding an integer."""

    int_val: int


@dataclass
class Bool:
    """EValue payload holding a boolean."""

    bool_val: bool


@dataclass
class Double:
    """EValue payload holding a double-precision value.

    Positive and negative infinity are stored as the strings ``"inf"`` and
    ``"-inf"``; every other value is stored as a ``float``.
    """

    double_val: Union[float, str]

    def __init__(self, double_val: Union[float, str]) -> None:
        # Because this class defines its own __init__, @dataclass keeps it
        # instead of generating one, and a hand-written __init__ does not run
        # __post_init__ automatically -- so the invariant check is invoked
        # explicitly at the end.
        if isinstance(double_val, str):
            # Strings are kept as-is; __post_init__ only accepts "inf"/"-inf".
            self.double_val = double_val
        else:
            # Coerce ints/bools (and float subclasses) to a plain float so the
            # stored value satisfies the isinstance check in __post_init__.
            value = float(double_val)
            if value == float("inf"):
                self.double_val = "inf"
            elif value == float("-inf"):
                self.double_val = "-inf"
            else:
                self.double_val = value
        self.__post_init__()

    def __post_init__(self) -> None:
        """Check that infinities are string-encoded and all else is a float."""
        if isinstance(self.double_val, str):
            assert self.double_val in [
                "inf",
                "-inf",
            ], f"Double only accepts 'inf' or '-inf' as strings, got {self.double_val!r}"
        else:
            assert isinstance(self.double_val, float)
            assert not self.double_val == float("inf")
            assert not self.double_val == float("-inf")


@dataclass
class String:
    """EValue payload holding a string."""

    string_val: str


@dataclass
class ContainerMetadata:
    """Encoded input/output container descriptions of an ExecutionPlan."""

    encoded_inp_str: str
    encoded_out_str: str


@dataclass
class IntList:
    """EValue payload holding a list of ints.

    NOTE(review): elements may be indices into the execution plan's ``values``
    rather than literal ints -- confirm against program.fbs.
    """

    items: List[int]


@dataclass
class DoubleList:
    """EValue payload holding a list of floats."""

    items: List[float]


@dataclass
class BoolList:
    """EValue payload holding a list of booleans."""

    items: List[bool]


@dataclass
class TensorList:
    """EValue payload holding a list of tensors.

    NOTE(review): ``items`` are presumably indices into the execution plan's
    ``values`` list -- confirm against program.fbs.
    """

    items: List[int]


# Every payload type an EValue may carry.
KernelTypes = Union[
    Int,
    Double,
    Bool,
    String,
    Tensor,
    IntList,
    BoolList,
    DoubleList,
    TensorList,
    Null,
    OptionalTensorList,
]


@dataclass
class EValue:
    """A tagged value in an execution plan's ``values`` list."""

    # Union types must be specified as strings so DataclassEncoder can see them.
    val: "KernelTypes"


@dataclass
class Buffer:
    """Raw bytes of one entry in ``Program.constant_buffer``."""

    storage: bytes


@dataclass
class BackendDelegateInlineData:
    """Raw bytes of one entry in ``Program.backend_delegate_data``."""

    data: bytes


@dataclass
class KernelCall:
    """Instruction arguments for invoking an operator.

    NOTE(review): ``op_index`` presumably indexes ``ExecutionPlan.operators``
    and ``args`` index ``ExecutionPlan.values`` -- confirm against program.fbs.
    """

    op_index: int
    args: List[int]


@dataclass
class DelegateCall:
    """Instruction arguments for invoking a backend delegate.

    NOTE(review): ``delegate_index`` presumably indexes
    ``ExecutionPlan.delegates`` -- confirm against program.fbs.
    """

    delegate_index: int
    args: List[int]


@dataclass
class MoveCall:
    """Instruction arguments for moving one value slot to another."""

    move_from: int
    move_to: int


@dataclass
class JumpFalseCall:
    """Instruction arguments for a conditional jump.

    NOTE(review): the name suggests a jump to ``destination_instruction``
    when the value at ``cond_value_index`` is false -- confirm in program.fbs.
    """

    cond_value_index: int
    destination_instruction: int


@dataclass
class FreeCall:
    """Instruction arguments for freeing a value slot."""

    value_index: int


# Every argument payload an Instruction may carry.
InstructionArguments = Union[
    KernelCall,
    DelegateCall,
    MoveCall,
    JumpFalseCall,
    FreeCall,
]


@dataclass
class Instruction:
    """One instruction in a Chain."""

    # Kept as a string annotation, like EValue.val, so DataclassEncoder can
    # see the Union.
    instr_args: "InstructionArguments"


@dataclass
class Frame:
    """One frame of a stack trace."""

    filename: str
    lineno: int
    name: str
    context: str


@dataclass
class FrameList:
    """A stack trace, as an ordered list of frames."""

    items: List[Frame]


class DataLocation(IntEnum):
    """Where a delegate's processed data is stored."""

    INLINE = 0
    SEGMENT = 1


@dataclass
class BackendDelegateDataReference:
    """Reference to a delegate's processed data.

    NOTE(review): ``index`` presumably points into
    ``Program.backend_delegate_data`` for INLINE and ``Program.segments`` for
    SEGMENT -- confirm against program.fbs.
    """

    location: DataLocation
    index: int


@dataclass
class BackendDelegate:
    """A backend delegate: its id, processed data, and compile specs."""

    id: str
    processed: BackendDelegateDataReference
    compile_specs: List[CompileSpec]


@dataclass
class Chain:
    """A sequence of instructions with its inputs, outputs and stack traces."""

    inputs: List[int]
    outputs: List[int]
    instructions: List[Instruction]
    stacktrace: Optional[List[FrameList]]


@dataclass
class Operator:
    """An operator referenced by KernelCall, identified by name and overload."""

    name: str
    overload: str


@dataclass
class ExecutionPlan:
    """A named, executable plan: values, chains, operators and delegates."""

    name: str
    container_meta_type: ContainerMetadata
    values: List[EValue]
    inputs: List[int]
    outputs: List[int]
    chains: List[Chain]
    operators: List[Operator]
    delegates: List[BackendDelegate]
    # the list index is memory buffer id, the value is the memory buffer size.
    # memory_buffer_id == 0 is special and is for constant memory buffer.
    # Runtime should use the len(constant_buffer) as the ground truth of
    # constant memory buffer size, and ignore non_const_buffer_sizes[0].
    non_const_buffer_sizes: List[int]


@dataclass
class DataSegment:
    """A segment of data described by an offset and a size."""

    offset: int
    size: int


@dataclass
class SubsegmentOffsets:
    """A list of offsets within the segment at ``segment_index``."""

    segment_index: int
    offsets: List[int]


@dataclass
class Program:
    """Top-level program: execution plans plus constant and delegate data."""

    version: int
    execution_plan: List[ExecutionPlan]
    constant_buffer: List[Buffer]
    backend_delegate_data: List[BackendDelegateInlineData]
    segments: List[DataSegment]
    constant_segment: SubsegmentOffsets
    mutable_data_segments: Optional[List[SubsegmentOffsets]] = None