xref: /aosp_15_r20/external/executorch/exir/serde/schema.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# NOTE: This is a placeholder for iterating on export serialization schema design.
8*523fa7a6SAndroid Build Coastguard Worker#       Anything is subject to change and no guarantee is provided at this point.
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field
11*523fa7a6SAndroid Build Coastguard Workerfrom enum import IntEnum
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional, Tuple
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.serde.schema as export_schema
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.serde.union import _Union
17*523fa7a6SAndroid Build Coastguard Worker
# Schema version as a (major, minor) pair, mirroring the fields of SchemaVersion.
# NOTE: Please update this value if any modifications are made to the schema.
SCHEMA_VERSION: Tuple[int, int] = (5, 3)
# Version counter kept separately from SCHEMA_VERSION; presumably tracks the
# serialized pytree treespec format — confirm with the serializer.
TREESPEC_VERSION: int = 1
21*523fa7a6SAndroid Build Coastguard Worker
22*523fa7a6SAndroid Build Coastguard Worker
class ScalarType(IntEnum):
    """Element type of a serialized tensor (see ``TensorMeta.dtype``).

    NOTE(review): the integer values are presumably what gets written to the
    serialized payload, so existing members should not be renumbered; add new
    members with fresh values instead.
    """

    UNKNOWN = 0
    # Integer types.
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    # Real floating-point types.
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    # Complex types.
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    # Remaining types, appended in numeric order.
    BOOL = 12
    BFLOAT16 = 13
    UINT16 = 14

39*523fa7a6SAndroid Build Coastguard Worker
class Layout(IntEnum):
    """Memory layout of a serialized tensor (see ``TensorMeta.layout``)."""

    Unknown = 0
    # Sparse layouts.
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    # Remaining layouts.
    _mkldnn = 6
    Strided = 7
49*523fa7a6SAndroid Build Coastguard Worker
50*523fa7a6SAndroid Build Coastguard Worker
class MemoryFormat(IntEnum):
    """Memory format value, usable as an ``Argument.as_memory_format``."""

    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4
57*523fa7a6SAndroid Build Coastguard Worker
58*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class Device:
    """Serialized device: a ``type`` string plus an optional ``index``.

    ``index`` defaults to ``None`` when no index is recorded.
    """

    type: str
    index: Optional[int] = None
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker
@dataclass(repr=False)
class SymExprHint(_Union):
    """Union (see ``_Union``) carrying a concrete hint for a ``SymExpr``.

    The hint is expressed as an int, a float, or a bool.
    """

    as_int: int
    as_float: float
    as_bool: bool
70*523fa7a6SAndroid Build Coastguard Worker
71*523fa7a6SAndroid Build Coastguard Worker
# Stores the symbolic expression behind a symint/symfloat/symbool. For example,
# we can get something like
#   SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    """Textual symbolic expression plus an optional concrete ``hint``."""

    expr_str: str
    # None when no concrete hint is available.
    hint: Optional[SymExprHint] = None
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker
@dataclass(repr=False)
class SymInt(_Union):
    """Union (see ``_Union``): either a symbolic expression or a plain int."""

    as_expr: SymExpr
    as_int: int
86*523fa7a6SAndroid Build Coastguard Worker
87*523fa7a6SAndroid Build Coastguard Worker
@dataclass(repr=False)
class SymBool(_Union):
    """Union (see ``_Union``): either a symbolic expression or a plain bool."""

    as_expr: SymExpr
    as_bool: bool
92*523fa7a6SAndroid Build Coastguard Worker
93*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class TensorMeta:
    """Serialized metadata describing one tensor value.

    ``sizes``, ``strides`` and ``storage_offset`` are ``SymInt`` unions, so
    each may be a symbolic expression or a concrete int.
    """

    dtype: ScalarType
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    strides: List[SymInt]
    storage_offset: SymInt
    layout: Layout
103*523fa7a6SAndroid Build Coastguard Worker
104*523fa7a6SAndroid Build Coastguard Worker
# Most SymInt arguments are stored by name in the "as_name" field.
# "as_int" exists for lists that mix SymInts and plain ints (ex. [1, s0, ...]):
# such a list is serialized as List[SymIntArgument], with each SymInt mapped
# to "as_name" and each int mapped to "as_int".
@dataclass(repr=False)
class SymIntArgument(_Union):
    """Union (see ``_Union``): a SymInt referenced by name, or a literal int."""

    as_name: str
    as_int: int
115*523fa7a6SAndroid Build Coastguard Worker
116*523fa7a6SAndroid Build Coastguard Worker
# Most SymBool arguments are stored by name in the "as_name" field.
# "as_bool" exists for lists that mix SymBools and plain bools
# (ex. [True, i0, ...]): such a list is serialized as List[SymBoolArgument],
# with each SymBool mapped to "as_name" and each bool mapped to "as_bool".
@dataclass(repr=False)
class SymBoolArgument(_Union):
    """Union (see ``_Union``): a SymBool referenced by name, or a literal bool."""

    as_name: str
    as_bool: bool
127*523fa7a6SAndroid Build Coastguard Worker
128*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class TensorArgument:
    """Reference to a tensor value by ``name``.

    NOTE(review): presumably the name keys into ``Graph.tensor_values``.
    """

    name: str
132*523fa7a6SAndroid Build Coastguard Worker
133*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class TokenArgument:
    """Named token, as carried by ``InputTokenSpec`` and ``OutputTokenSpec``."""

    name: str
137*523fa7a6SAndroid Build Coastguard Worker
138*523fa7a6SAndroid Build Coastguard Worker
# This is used for storing the contents of a list that contains optional
# tensors (Tensor?[], ex. [Tensor, None, ...]). Such a list is serialized as
# List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    """Union (see ``_Union``): a tensor reference or an explicit ``None``."""

    as_tensor: TensorArgument
    as_none: Tuple[()]
147*523fa7a6SAndroid Build Coastguard Worker
148*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class GraphArgument:
    """Named nested ``Graph`` passed as an argument.

    ``"Graph"`` is a string forward reference because ``Graph`` is defined
    later in this module.
    """

    name: str
    graph: "Graph"
153*523fa7a6SAndroid Build Coastguard Worker
154*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class CustomObjArgument:
    """Named custom object together with its class identifier.

    NOTE(review): ``class_fqn`` is presumably the fully-qualified class name.
    """

    name: str
    class_fqn: str
159*523fa7a6SAndroid Build Coastguard Worker
160*523fa7a6SAndroid Build Coastguard Worker
# This is actually a union type.
@dataclass(repr=False)
class Argument(_Union):
    """Union (see ``_Union``) over every kind of value a node argument can be.

    NOTE(review): presumably exactly one alternative is populated per
    instance; ``_Union`` (not visible here) defines the exact contract.
    """

    as_none: Tuple[()]
    # Tensor references.
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    # Python scalars and their lists.
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    # Symbolic ints.
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    # Tensor-attribute enums and devices.
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    # Bools and symbolic bools.
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    # Everything else.
    as_graph: GraphArgument
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
    as_operator: str
187*523fa7a6SAndroid Build Coastguard Worker
188*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class NamedArgument:
    """An ``Argument`` paired with the parameter name it binds to."""

    # Argument name from the operator schema.
    name: str
    arg: Argument
194*523fa7a6SAndroid Build Coastguard Worker
195*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class Node:
    """One serialized graph node.

    A node holds a ``target`` string, its named inputs, its outputs, and a
    string-to-string ``metadata`` map.
    """

    target: str
    inputs: List[NamedArgument]
    outputs: List[Argument]
    metadata: Dict[str, str]
202*523fa7a6SAndroid Build Coastguard Worker
203*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class Graph:
    """Serialized graph.

    A graph consists of its inputs, outputs and nodes, plus name-keyed tables
    of tensor metadata, symbolic int/bool values and custom objects.
    """

    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # Used when deserializing submodule graphs of higher-order ops (ex. cond,
    # map). Those graphs return a single tensor directly instead of following
    # the export schema's convention of returning a singleton list.
    is_single_tensor_return: bool = False
    # default_factory gives each instance its own dict rather than a shared one.
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
218*523fa7a6SAndroid Build Coastguard Worker
219*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class UserInputSpec:
    """Graph input that comes from the user."""

    # Although typed as Argument, only tensors and SymInts are allowed here.
    arg: Argument
224*523fa7a6SAndroid Build Coastguard Worker
225*523fa7a6SAndroid Build Coastguard Worker
@dataclass(repr=False)
class ConstantValue(_Union):
    """Union (see ``_Union``) of constant input values.

    The value is one of None, an int, a float, a string, or a bool.
    """

    as_none: Tuple[()]
    as_int: int
    as_float: float
    as_string: str
    as_bool: bool
233*523fa7a6SAndroid Build Coastguard Worker
234*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class ConstantInputSpec:
    """Graph input bound to a constant: its ``name`` and ``value``."""

    name: str
    value: ConstantValue
239*523fa7a6SAndroid Build Coastguard Worker
240*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class InputToParameterSpec:
    """Input-to-parameter binding: tensor input ``arg`` and ``parameter_name``."""

    arg: TensorArgument
    parameter_name: str
245*523fa7a6SAndroid Build Coastguard Worker
246*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class InputToBufferSpec:
    """Input-to-buffer binding: tensor input ``arg`` and ``buffer_name``.

    NOTE(review): ``persistent`` presumably mirrors the module buffer's
    persistence flag (whether it is part of the state dict) — confirm.
    """

    arg: TensorArgument
    buffer_name: str
    persistent: bool
252*523fa7a6SAndroid Build Coastguard Worker
253*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class InputToTensorConstantSpec:
    """Input-to-tensor-constant binding.

    Pairs the tensor input ``arg`` with ``tensor_constant_name``.
    """

    arg: TensorArgument
    tensor_constant_name: str
258*523fa7a6SAndroid Build Coastguard Worker
259*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class InputToCustomObjSpec:
    """Input-to-custom-object binding: object input ``arg`` and its name."""

    arg: CustomObjArgument
    custom_obj_name: str
264*523fa7a6SAndroid Build Coastguard Worker
265*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class InputTokenSpec:
    """Graph input that is a token."""

    arg: TokenArgument
269*523fa7a6SAndroid Build Coastguard Worker
270*523fa7a6SAndroid Build Coastguard Worker
@dataclass(repr=False)
class InputSpec(_Union):
    """Union (see ``_Union``) over the kinds of graph input specs."""

    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
    custom_obj: InputToCustomObjSpec
    token: InputTokenSpec
    constant_input: ConstantInputSpec
280*523fa7a6SAndroid Build Coastguard Worker
281*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class UserOutputSpec:
    """Graph output that is returned to the user."""

    arg: Argument
285*523fa7a6SAndroid Build Coastguard Worker
286*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class LossOutputSpec:
    """Graph output tensor designated as the loss."""

    arg: TensorArgument
290*523fa7a6SAndroid Build Coastguard Worker
291*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class BufferMutationSpec:
    """Output tensor ``arg`` associated with a mutation of ``buffer_name``."""

    arg: TensorArgument
    buffer_name: str
296*523fa7a6SAndroid Build Coastguard Worker
297*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class GradientToParameterSpec:
    """Output tensor ``arg`` paired with the parameter ``parameter_name``."""

    arg: TensorArgument
    parameter_name: str
302*523fa7a6SAndroid Build Coastguard Worker
303*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class GradientToUserInputSpec:
    """Output tensor ``arg`` paired with the user input ``user_input_name``."""

    arg: TensorArgument
    user_input_name: str
308*523fa7a6SAndroid Build Coastguard Worker
309*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class UserInputMutationSpec:
    """Output tensor ``arg`` associated with a mutation of ``user_input_name``."""

    arg: TensorArgument
    user_input_name: str
314*523fa7a6SAndroid Build Coastguard Worker
315*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class OutputTokenSpec:
    """Graph output that is a token."""

    arg: TokenArgument
319*523fa7a6SAndroid Build Coastguard Worker
320*523fa7a6SAndroid Build Coastguard Worker
@dataclass(repr=False)
class OutputSpec(_Union):
    """Union (see ``_Union``) over the kinds of graph output specs."""

    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
    user_input_mutation: UserInputMutationSpec
    token: OutputTokenSpec
330*523fa7a6SAndroid Build Coastguard Worker
331*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class GraphSignature:
    """Ordered input and output specs describing a graph's signature."""

    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]
336*523fa7a6SAndroid Build Coastguard Worker
337*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class RangeConstraint:
    """Integer value range given by ``min_val`` and ``max_val``.

    NOTE(review): whether the bounds are inclusive is not established here.
    """

    min_val: int
    max_val: int
342*523fa7a6SAndroid Build Coastguard Worker
343*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class ModuleCallSignature:
    """Calling signature recorded for one module call.

    ``inputs``/``outputs`` are the flattened arguments. ``in_spec`` and
    ``out_spec`` are pytree TreeSpecs kept in their serialized string form.
    """

    inputs: List[Argument]
    outputs: List[Argument]

    # These are serialized by calling pytree.treespec_dumps (TreeSpec -> str)
    # and deserialized by calling pytree.treespec_loads (str -> TreeSpec).
    in_spec: str
    out_spec: str
353*523fa7a6SAndroid Build Coastguard Worker
354*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class ModuleCallEntry:
    """Entry in ``GraphModule.module_call_graph``.

    Holds a module identified by ``fqn`` and its optional call signature,
    which is ``None`` when absent.
    """

    fqn: str
    signature: Optional[ModuleCallSignature] = None
359*523fa7a6SAndroid Build Coastguard Worker
360*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class GraphModule:
    """A serialized graph together with its signature and module call graph."""

    graph: Graph
    signature: GraphSignature
    # Used for unflattening. It tracks the calling structure of all modules so
    # they can be unflattened back to their eager calling conventions.
    module_call_graph: List[ModuleCallEntry]
369*523fa7a6SAndroid Build Coastguard Worker
370*523fa7a6SAndroid Build Coastguard Worker
# Invariant: every time a change is made to the schema, one of the versions
#            should be updated.
@dataclass
class SchemaVersion:
    """Two-part (major, minor) version of the serialization schema."""

    # Bumped every time a breaking change is made.
    major: int
    # Bumped when a compatible change is made.
    minor: int
377*523fa7a6SAndroid Build Coastguard Worker
378*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class ExportedProgram:
    """Top-level record of a serialized exported program.

    Field order, and therefore the positional constructor signature, is
    ``ExportedProgram(graph_module, opset_version, range_constraints,
    schema_version, dialect="", verifiers=[])``. ``dialect`` and ``verifiers``
    are optional.
    """

    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number.
    opset_version: Dict[str, int]
    range_constraints: Dict[str, RangeConstraint]
    schema_version: SchemaVersion
    # TODO deprecated.
    # Declared exactly once, ahead of `verifiers`. This matches the position
    # and default the dataclass already used. The duplicate default-less
    # declaration that used to precede `verifiers` was overridden by a later
    # one, and it wrongly suggested that `dialect` was a required field.
    dialect: str = ""
    verifiers: List[str] = field(default_factory=list)
389*523fa7a6SAndroid Build Coastguard Worker
390*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class CompileSpec:
    """Single compile-spec entry, stored as a string ``key`` and ``value``."""

    key: str
    value: str
395*523fa7a6SAndroid Build Coastguard Worker
396*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class LoweredBackendModule:
    """Serialized lowered backend module.

    Records the backend id, the processed payload, the compile specs used,
    and the original exported program with its state dict and constants.
    NOTE(review): ``processed_bytes``, ``original_state_dict`` and
    ``original_constants`` are typed ``str``; their encoding is set by the
    serializer, which is not visible here.
    """

    backend_id: str
    processed_bytes: str
    compile_specs: List[CompileSpec]
    # NOTE(review): `export_schema` is imported from executorch.exir.serde.schema.
    # Per the file path, that is this very module, so this presumably resolves
    # to the ExportedProgram defined above — confirm.
    original_module: export_schema.ExportedProgram
    original_state_dict: str
    original_constants: str
404*523fa7a6SAndroid Build Coastguard Worker    original_constants: str
405