# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# NOTE: This is a placeholder for iterating on export serialization schema design.
# Anything is subject to change and no guarantee is provided at this point.

from dataclasses import dataclass, field
from enum import IntEnum
from typing import Dict, List, Optional, Tuple

import executorch.exir.serde.schema as export_schema

from executorch.exir.serde.union import _Union

# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (5, 3)
TREESPEC_VERSION = 1


class ScalarType(IntEnum):
    UNKNOWN = 0
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    BOOL = 12
    BFLOAT16 = 13
    UINT16 = 14


class Layout(IntEnum):
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    _mkldnn = 6
    Strided = 7


class MemoryFormat(IntEnum):
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4


@dataclass
class Device:
    type: str
    index: Optional[int] = None
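

# Illustrative sketch only (hypothetical helper, not used by the serializer):
# shows how a device string such as "cpu" or "cuda:1" would map onto the two
# fields of `Device`, with `index` left as None when no ordinal is given.
# Assumption: the PyTorch "<type>:<index>" device-string convention.
def _example_split_device_string(device: str) -> Tuple[str, Optional[int]]:
    # partition() splits at the first ":" only; sep is "" when there is none.
    type_, sep, index = device.partition(":")
    if not sep:
        return type_, None
    # A malformed ordinal (e.g. "cuda:" or "cuda:x") raises ValueError here.
    return type_, int(index)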


@dataclass(repr=False)
class SymExprHint(_Union):
    as_int: int
    as_float: float
    as_bool: bool


# This is for storing the symbolic expressions behind symints/symfloats/symbools.
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    expr_str: str
    hint: Optional[SymExprHint] = None


@dataclass(repr=False)
class SymInt(_Union):
    as_expr: SymExpr
    as_int: int


@dataclass(repr=False)
class SymBool(_Union):
    as_expr: SymExpr
    as_bool: bool


@dataclass
class TensorMeta:
    dtype: ScalarType
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    strides: List[SymInt]
    storage_offset: SymInt
    layout: Layout


# In most cases we will use the "as_name" field to store arguments which are
# SymInts.
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    as_name: str
    as_int: int


# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of
# list to be List[SymBoolArgument] and map the SymBools to the "as_name" field,
# and bools to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    as_name: str
    as_bool: bool


@dataclass
class TensorArgument:
    name: str


@dataclass
class TokenArgument:
    name: str


# This is used for storing the contents of a list which contains optional
# tensors (Tensor?[], ex. [Tensor, None, ...]), where the list will be
# serialized to the type List[OptionalTensorArgument], with tensor values
# serialized to the "as_tensor" field, and None values serialized to the
# "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    as_tensor: TensorArgument
    as_none: Tuple[()]


@dataclass
class GraphArgument:
    name: str
    graph: "Graph"


@dataclass
class CustomObjArgument:
    name: str
    class_fqn: str


# This is actually a union type
@dataclass(repr=False)
class Argument(_Union):
    as_none: Tuple[()]
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    as_graph: GraphArgument
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
    as_operator: str


@dataclass
class NamedArgument:
    # Argument name from the operator schema
    name: str
    arg: Argument


@dataclass
class Node:
    target: str
    inputs: List[NamedArgument]
    outputs: List[Argument]
    metadata: Dict[str, str]


@dataclass
class Graph:
    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: bool = False
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)


@dataclass
class UserInputSpec:
    # Actually, only tensors and SymInts are allowed here
    arg: Argument


@dataclass(repr=False)
class ConstantValue(_Union):
    as_none: Tuple[()]
    as_int: int
    as_float: float
    as_string: str
    as_bool: bool
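

# Illustrative sketch only (hypothetical helper, not part of the serializer):
# recomputes the hint from the `SymExpr` comment earlier in this file, where
# expr_str="s0 + s1" with s0 = s1 = 2 gives an int hint of 4. Assumption:
# expr_str is a plain arithmetic expression over symbol names; this toy
# evaluator supports only +, - and * on ints, while real expressions can be
# much richer.
def _example_eval_int_hint(expr_str: str, symbol_hints: Dict[str, int]) -> int:
    import ast
    import operator

    ops = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}

    def _eval(node: ast.AST) -> int:
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.BinOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.Name):
            # Raises KeyError if a symbol has no hint.
            return symbol_hints[node.id]
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        raise ValueError(f"unsupported expression node: {ast.dump(node)}")

    # mode="eval" parses a single expression into an ast.Expression node.
    return _eval(ast.parse(expr_str, mode="eval"))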


@dataclass
class ConstantInputSpec:
    name: str
    value: ConstantValue


@dataclass
class InputToParameterSpec:
    arg: TensorArgument
    parameter_name: str


@dataclass
class InputToBufferSpec:
    arg: TensorArgument
    buffer_name: str
    persistent: bool


@dataclass
class InputToTensorConstantSpec:
    arg: TensorArgument
    tensor_constant_name: str


@dataclass
class InputToCustomObjSpec:
    arg: CustomObjArgument
    custom_obj_name: str


@dataclass
class InputTokenSpec:
    arg: TokenArgument
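

# Illustrative sketch only (hypothetical helper, not part of the serializer):
# mirrors the rule described above `SymIntArgument` for a mixed list such as
# [1, s0, ...]. Plain ints go to the "as_int" field and symbolic values go to
# the "as_name" field. Assumptions: a symbolic value is stood in for by its
# name (a str), and each element is returned as a (field_name, value) pair
# rather than a real `SymIntArgument`. `SymBoolArgument` follows the same
# pattern with "as_bool" in place of "as_int".
def _example_encode_mixed_sym_ints(values: List[object]) -> List[Tuple[str, object]]:
    encoded: List[Tuple[str, object]] = []
    for v in values:
        # bool is a subclass of int in Python; reject it so True/False are not
        # silently serialized as 1/0.
        if isinstance(v, bool):
            raise TypeError(f"unexpected bool in SymInt list: {v!r}")
        if isinstance(v, int):
            encoded.append(("as_int", v))
        elif isinstance(v, str):
            encoded.append(("as_name", v))
        else:
            raise TypeError(f"unsupported element: {v!r}")
    return encoded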


@dataclass(repr=False)
class InputSpec(_Union):
    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
    custom_obj: InputToCustomObjSpec
    token: InputTokenSpec
    constant_input: ConstantInputSpec


@dataclass
class UserOutputSpec:
    arg: Argument


@dataclass
class LossOutputSpec:
    arg: TensorArgument


@dataclass
class BufferMutationSpec:
    arg: TensorArgument
    buffer_name: str


@dataclass
class GradientToParameterSpec:
    arg: TensorArgument
    parameter_name: str
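

# Illustrative sketch only (hypothetical helper, not the `_Union` API): the
# `_Union` subclasses in this file (e.g. `Argument`, `ConstantValue`,
# `InputSpec`) are union types, so a serialized value is expected to carry
# exactly one populated field. This helper checks that rule on a plain dict of
# field values and returns the (tag, value) pair. Note that "as_none" is
# populated with an empty tuple, not None. How `_Union` itself enforces the
# rule is not shown here.
def _example_union_tag(field_values: Dict[str, object]) -> Tuple[str, object]:
    set_fields = [(k, v) for k, v in field_values.items() if v is not None]
    if len(set_fields) != 1:
        names = [k for k, _ in set_fields]
        raise ValueError(f"expected exactly one field to be set, got {names}")
    return set_fields[0]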


@dataclass
class GradientToUserInputSpec:
    arg: TensorArgument
    user_input_name: str


@dataclass
class UserInputMutationSpec:
    arg: TensorArgument
    user_input_name: str


@dataclass
class OutputTokenSpec:
    arg: TokenArgument


@dataclass(repr=False)
class OutputSpec(_Union):
    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
    user_input_mutation: UserInputMutationSpec
    token: OutputTokenSpec


@dataclass
class GraphSignature:
    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]


@dataclass
class RangeConstraint:
    min_val: int
    max_val: int


@dataclass
class ModuleCallSignature:
    inputs: List[Argument]
    outputs: List[Argument]

    # These are serialized by calling pytree.treespec_dumps
    # and deserialized by calling pytree.treespec_loads.
    in_spec: str
    out_spec: str


@dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None


@dataclass
class GraphModule:
    graph: Graph
    signature: GraphSignature
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: List[ModuleCallEntry]


# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
    major: int  # Major version number is bumped every time a breaking change is made.
    minor: int  # Minor version number is bumped when a compatible change is made.


@dataclass
class ExportedProgram:
    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Dict[str, int]
    range_constraints: Dict[str, RangeConstraint]
    schema_version: SchemaVersion
    dialect: str = ""  # TODO deprecated
    verifiers: List[str] = field(default_factory=list)


@dataclass
class CompileSpec:
    key: str
    value: str


@dataclass
class LoweredBackendModule:
    backend_id: str
    processed_bytes: str
    compile_specs: List[CompileSpec]
    original_module: export_schema.ExportedProgram
    original_state_dict: str
    original_constants: str
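

# Illustrative sketch only (hypothetical helper, not an existing check): one
# way to act on the `SchemaVersion` invariant above. Because a major bump marks
# a breaking change and a minor bump a compatible one, this sketch assumes a
# reader can load a payload when the major versions match and the payload's
# minor version is not newer than the reader's. The real compatibility policy
# may differ. Example: _example_can_read(SCHEMA_VERSION, (5, 1)).
def _example_can_read(reader: Tuple[int, int], payload: Tuple[int, int]) -> bool:
    reader_major, reader_minor = reader
    payload_major, payload_minor = payload
    return reader_major == payload_major and payload_minor <= reader_minor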