# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# NOTE: This is a placeholder for iterating on export serialization schema design.
# Anything is subject to change and no guarantee is provided at this point.

from dataclasses import dataclass, field
from enum import IntEnum
from typing import Dict, List, Optional, Tuple

import executorch.exir.serde.schema as export_schema

from executorch.exir.serde.union import _Union

# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (5, 3)
TREESPEC_VERSION = 1


class ScalarType(IntEnum):
    UNKNOWN = 0
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    BOOL = 12
    BFLOAT16 = 13
    UINT16 = 14


class Layout(IntEnum):
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    _mkldnn = 6
    Strided = 7


class MemoryFormat(IntEnum):
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4


@dataclass
class Device:
    type: str
    index: Optional[int] = None


@dataclass(repr=False)
class SymExprHint(_Union):
    as_int: int
    as_float: float
    as_bool: bool


# This is for storing the symbolic expressions behind symints/symfloats/symbools
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    expr_str: str
    hint: Optional[SymExprHint] = None


@dataclass(repr=False)
class SymInt(_Union):
    as_expr: SymExpr
    as_int: int


@dataclass(repr=False)
class SymBool(_Union):
    as_expr: SymExpr
    as_bool: bool


@dataclass
class TensorMeta:
    dtype: ScalarType
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    strides: List[SymInt]
    storage_offset: SymInt
    layout: Layout


# In most cases we will use the "as_name" field to store arguments which are
# SymInts.
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInts and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    as_name: str
    as_int: int


# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBools and bools (ex. [True, i0, ...]). We will serialize this type of list to
# be List[SymBoolArgument] and map the SymBools to the "as_name" field, and bools
# to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    as_name: str
    as_bool: bool
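

# Illustrative example (not part of the schema; the symbol names "s0" and "i0"
# follow the comments above): a mixed list [1, s0] would serialize to
#   [SymIntArgument(as_int=1), SymIntArgument(as_name="s0")]
# and a mixed list [True, i0] would serialize to
#   [SymBoolArgument(as_bool=True), SymBoolArgument(as_name="i0")]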


@dataclass
class TensorArgument:
    name: str


@dataclass
class TokenArgument:
    name: str


# This is used for storing the contents of a list which contains optional tensors
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    as_tensor: TensorArgument
    as_none: Tuple[()]


@dataclass
class GraphArgument:
    name: str
    graph: "Graph"


@dataclass
class CustomObjArgument:
    name: str
    class_fqn: str


# This is actually a union type
@dataclass(repr=False)
class Argument(_Union):
    as_none: Tuple[()]
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    as_graph: GraphArgument
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
    as_operator: str


@dataclass
class NamedArgument:
    # Argument name from the operator schema
    name: str
    arg: Argument


@dataclass
class Node:
    target: str
    inputs: List[NamedArgument]
    outputs: List[Argument]
    metadata: Dict[str, str]


@dataclass
class Graph:
    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: bool = False
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)


@dataclass
class UserInputSpec:
    # Actually, only tensors and SymInts are allowed here
    arg: Argument


@dataclass(repr=False)
class ConstantValue(_Union):
    as_none: Tuple[()]
    as_int: int
    as_float: float
    as_string: str
    as_bool: bool


@dataclass
class ConstantInputSpec:
    name: str
    value: ConstantValue


@dataclass
class InputToParameterSpec:
    arg: TensorArgument
    parameter_name: str


@dataclass
class InputToBufferSpec:
    arg: TensorArgument
    buffer_name: str
    persistent: bool


@dataclass
class InputToTensorConstantSpec:
    arg: TensorArgument
    tensor_constant_name: str


@dataclass
class InputToCustomObjSpec:
    arg: CustomObjArgument
    custom_obj_name: str


@dataclass
class InputTokenSpec:
    arg: TokenArgument


@dataclass(repr=False)
class InputSpec(_Union):
    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
    custom_obj: InputToCustomObjSpec
    token: InputTokenSpec
    constant_input: ConstantInputSpec


@dataclass
class UserOutputSpec:
    arg: Argument


@dataclass
class LossOutputSpec:
    arg: TensorArgument


@dataclass
class BufferMutationSpec:
    arg: TensorArgument
    buffer_name: str


@dataclass
class GradientToParameterSpec:
    arg: TensorArgument
    parameter_name: str


@dataclass
class GradientToUserInputSpec:
    arg: TensorArgument
    user_input_name: str


@dataclass
class UserInputMutationSpec:
    arg: TensorArgument
    user_input_name: str


@dataclass
class OutputTokenSpec:
    arg: TokenArgument


@dataclass(repr=False)
class OutputSpec(_Union):
    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
    user_input_mutation: UserInputMutationSpec
    token: OutputTokenSpec


@dataclass
class GraphSignature:
    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]


@dataclass
class RangeConstraint:
    min_val: int
    max_val: int
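

# Illustrative example (not part of the schema): a symbolic dimension constrained
# to lie between 2 and 1024 would be represented as
#   RangeConstraint(min_val=2, max_val=1024)
# In ExportedProgram below, range_constraints is assumed to map a symbol name
# (e.g. "s0") to its RangeConstraint.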


@dataclass
class ModuleCallSignature:
    inputs: List[Argument]
    outputs: List[Argument]

    # These are serialized by calling pytree.treespec_dumps
    # and deserialized by calling pytree.treespec_loads
    in_spec: str
    out_spec: str


@dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None


@dataclass
class GraphModule:
    graph: Graph
    signature: GraphSignature
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: List[ModuleCallEntry]


# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
    major: int  # Major version number is bumped every time a breaking change is made.
    minor: int  # Minor version number is bumped when a compatible change is made.


@dataclass
class ExportedProgram:
    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Dict[str, int]
    range_constraints: Dict[str, RangeConstraint]
    schema_version: SchemaVersion
    verifiers: List[str] = field(default_factory=list)
    dialect: str = ""  # TODO deprecated


@dataclass
class CompileSpec:
    key: str
    value: str


@dataclass
class LoweredBackendModule:
    backend_id: str
    processed_bytes: str
    compile_specs: List[CompileSpec]
    original_module: export_schema.ExportedProgram
    original_state_dict: str
    original_constants: str
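

# Illustrative example (not part of the schema; the backend name, spec key, and
# values are hypothetical): a module delegated to a backend registered as
# "ExampleBackend" might be represented as
#   LoweredBackendModule(
#       backend_id="ExampleBackend",
#       processed_bytes="...",  # backend-specific payload produced during lowering
#       compile_specs=[CompileSpec(key="platform", value="cpu")],
#       original_module=...,  # the export_schema.ExportedProgram that was lowered
#       original_state_dict="...",
#       original_constants="...",
#   )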