xref: /aosp_15_r20/external/executorch/backends/vulkan/serialization/vulkan_graph_schema.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9"""
Python mirror of the Vulkan delegate serialization schema (dataclasses and IntEnums).
Please refer to fbcode/caffe2/executorch/backends/vulkan/serialization/schema/schema.fbs for the schema definitions.
11"""
12
13from dataclasses import dataclass
14from enum import IntEnum
15from typing import List, Union
16
17
@dataclass
class OperatorCall:
    """One operator invocation in a serialized Vulkan graph.

    Mirrors the corresponding definition in the Vulkan serialization
    schema (schema.fbs).
    """

    # Integer id associated with this call (presumably the source graph
    # node — TODO confirm against the graph builder).
    node_id: int
    # Operator name, stored as a plain string.
    name: str
    # Integer ids of the call's arguments; presumably indices into
    # ``VkGraph.values`` — confirm against the graph builder.
    args: List[int]
23
24
class VkDataType(IntEnum):
    """Element data types that a serialized ``VkTensor`` may carry.

    Members are plain integers (``IntEnum``). Keep the numeric values in sync
    with the schema, since the number is presumably what gets serialized.
    """

    # Boolean and integer element types.
    BOOL = 0
    UINT8 = 1
    INT8 = 2
    INT32 = 3
    # Floating-point element types.
    FLOAT16 = 4
    FLOAT32 = 5
32
33
class VkStorageType(IntEnum):
    """Storage kinds a tensor (or a graph-wide override) may request.

    ``str()`` of a member yields its bare name (e.g. ``"BUFFER"``) instead of
    the default enum rendering.
    """

    BUFFER = 0
    TEXTURE_3D = 1
    TEXTURE_2D = 2
    # 0xFF == 255. Presumably a "no explicit choice" sentinel; confirm its
    # meaning against schema.fbs / the runtime before relying on it.
    DEFAULT_STORAGE = 0xFF

    def __str__(self) -> str:
        """Return the member name, e.g. ``"TEXTURE_3D"``."""
        label: str = self.name
        return label
42
43
class VkMemoryLayout(IntEnum):
    """Packing layouts a tensor (or a graph-wide override) may request.

    ``str()`` of a member yields its bare name (e.g.
    ``"TENSOR_WIDTH_PACKED"``) instead of the default enum rendering.
    """

    TENSOR_WIDTH_PACKED = 0
    TENSOR_HEIGHT_PACKED = 1
    TENSOR_CHANNELS_PACKED = 2
    # 0xFF == 255. Presumably a "no explicit choice" sentinel; confirm its
    # meaning against schema.fbs / the runtime before relying on it.
    DEFAULT_LAYOUT = 0xFF

    def __str__(self) -> str:
        """Return the member name, e.g. ``"TENSOR_CHANNELS_PACKED"``."""
        label: str = self.name
        return label
52
53
@dataclass
class VkTensor:
    """Description of a tensor value stored in a Vulkan graph.

    ``storage_type`` and ``memory_layout`` are optional. When omitted, they
    take the ``DEFAULT_*`` sentinel members of their enums.
    """

    # Element type of the tensor.
    datatype: VkDataType
    # Dimension sizes of the tensor.
    dims: List[int]
    # Presumably an index into ``VkGraph.constants`` — TODO confirm.
    constant_id: int
    # Presumably an id of a memory object that tensors may share — confirm.
    mem_obj_id: int
    # Optional per-tensor requests; the sentinels mean "not specified here".
    storage_type: VkStorageType = VkStorageType.DEFAULT_STORAGE
    memory_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT
62
63
@dataclass
class Null:
    """Graph value with no payload (presumably stands in for a None argument)."""
67
68
@dataclass
class Int:
    """Graph value wrapping a single integer scalar."""

    int_val: int
72
73
@dataclass
class Bool:
    """Graph value wrapping a single boolean scalar."""

    bool_val: bool
77
78
@dataclass
class Double:
    """Graph value wrapping a single floating-point scalar (a Python float)."""

    double_val: float
82
83
@dataclass
class IntList:
    """Graph value holding a list of integer scalars."""

    items: List[int]
87
88
@dataclass
class DoubleList:
    """Graph value holding a list of floating-point scalars."""

    items: List[float]
92
93
@dataclass
class BoolList:
    """Graph value holding a list of boolean scalars."""

    items: List[bool]
97
98
@dataclass
class ValueList:
    """Graph value holding a list of integer ids.

    The ids presumably reference other entries of ``VkGraph.values`` —
    confirm against the graph builder.
    """

    items: List[int]
102
103
@dataclass
class String:
    """Graph value wrapping a single string."""

    string_val: str
107
108
@dataclass
class SymInt:
    """Graph value for a symbolic integer.

    The meaning of ``value`` (a concrete integer vs. an id of another graph
    value) is not visible here — confirm against the graph builder.
    """

    value: int
112
113
# Every payload type a VkValue may hold. Member order is kept exactly as
# before; it presumably mirrors the schema union, so check schema.fbs before
# reordering.
GraphTypes = Union[
    # Scalars (and the empty value).
    Null,
    Int,
    Double,
    Bool,
    # Tensors.
    VkTensor,
    # Lists.
    IntList,
    BoolList,
    DoubleList,
    ValueList,
    # Other value kinds.
    String,
    SymInt,
]
126]
127
128
@dataclass
class VkValue:
    """One entry of ``VkGraph.values``, wrapping a single ``GraphTypes`` payload."""

    # String annotation: it is resolved only when type hints are evaluated
    # (e.g. via typing.get_type_hints), not at class creation.
    value: "GraphTypes"
132
133
@dataclass
class VkBytes:
    """A byte range described by its starting offset and its length.

    Used for ``VkGraph.constants`` and ``VkGraph.shaders``. The offsets are
    presumably relative to a data segment stored alongside the serialized
    graph — confirm.
    """

    offset: int
    length: int
137    length: int
138
139
@dataclass
class VkGraph:
    """Aggregate describing a whole Vulkan graph.

    It bundles the operator chain, the value table, the input/output ids,
    the constant and shader byte ranges, and optional graph-wide
    storage/layout overrides. The overrides default to the ``DEFAULT_*``
    sentinels.
    """

    # Version string of the serialized graph.
    version: str

    # Operator calls; presumably listed in execution order — confirm.
    chain: List[OperatorCall]
    # Value table; ids elsewhere presumably index into this list.
    values: List[VkValue]

    # Graph inputs / outputs, presumably as indices into ``values``.
    input_ids: List[int]
    output_ids: List[int]

    # Byte ranges for constant data and shaders (see VkBytes).
    constants: List[VkBytes]
    shaders: List[VkBytes]

    # Optional graph-wide overrides; how the runtime applies them is not
    # visible in this file.
    storage_type_override: VkStorageType = VkStorageType.DEFAULT_STORAGE
    memory_layout_override: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT
155