# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Python-side mirror of the Vulkan delegate serialization schema.

The authoritative schema definitions live in
fbcode/caffe2/executorch/backends/vulkan/serialization/schema/schema.fbs;
please refer to that file. Names, field order, and enum values in this module
should presumably be kept in sync with it.
"""

from dataclasses import dataclass
from enum import IntEnum
from typing import List, Union


# ---------------------------------------------------------------------------
# Operator calls
# ---------------------------------------------------------------------------


@dataclass
class OperatorCall:
    """A single named operator invocation.

    ``node_id`` and ``args`` are plain integers. NOTE(review): ``args``
    presumably holds ids of entries in the graph's value table; confirm
    against schema.fbs.
    """

    node_id: int
    name: str
    args: List[int]


# ---------------------------------------------------------------------------
# Enumerations (integer tags)
# ---------------------------------------------------------------------------


class VkDataType(IntEnum):
    """Integer tag identifying a tensor's element data type."""

    BOOL = 0
    UINT8 = 1
    INT8 = 2
    INT32 = 3
    FLOAT16 = 4
    FLOAT32 = 5


class VkStorageType(IntEnum):
    """Integer tag identifying how a tensor's storage is backed.

    ``DEFAULT_STORAGE`` (255) is a sentinel placed well outside the concrete
    kinds. NOTE(review): presumably it means "no explicit choice"; confirm
    against schema.fbs.
    """

    BUFFER = 0
    TEXTURE_3D = 1
    TEXTURE_2D = 2
    DEFAULT_STORAGE = 255

    def __str__(self) -> str:
        # Render as the bare member name (e.g. "BUFFER") rather than the
        # IntEnum default, whose output differs across Python versions.
        return self._name_


class VkMemoryLayout(IntEnum):
    """Integer tag identifying which dimension a tensor is packed along.

    ``DEFAULT_LAYOUT`` (255) is a sentinel analogous to
    ``VkStorageType.DEFAULT_STORAGE``.
    """

    TENSOR_WIDTH_PACKED = 0
    TENSOR_HEIGHT_PACKED = 1
    TENSOR_CHANNELS_PACKED = 2
    DEFAULT_LAYOUT = 255

    def __str__(self) -> str:
        # Same bare-name rendering as VkStorageType.__str__.
        return self._name_


# ---------------------------------------------------------------------------
# Value payloads
# ---------------------------------------------------------------------------


@dataclass
class VkTensor:
    """Tensor payload: dtype, dims, and storage/layout selections.

    NOTE(review): ``constant_id`` and ``mem_obj_id`` are bare integers here;
    their meaning (e.g. an index into the constants table, a sentinel for
    "none") is not established in this module; confirm against schema.fbs.
    Storage type and memory layout fall back to the DEFAULT_* sentinels.
    """

    datatype: VkDataType
    dims: List[int]
    constant_id: int
    mem_obj_id: int
    storage_type: VkStorageType = VkStorageType.DEFAULT_STORAGE
    memory_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT


@dataclass
class Null:
    """Payload that carries no data."""


@dataclass
class Int:
    """Scalar integer payload."""

    int_val: int


@dataclass
class Bool:
    """Scalar boolean payload."""

    bool_val: bool


@dataclass
class Double:
    """Scalar floating-point payload."""

    double_val: float


@dataclass
class IntList:
    """List-of-integers payload."""

    items: List[int]


@dataclass
class DoubleList:
    """List-of-floats payload."""

    items: List[float]


@dataclass
class BoolList:
    """List-of-booleans payload."""

    items: List[bool]


@dataclass
class ValueList:
    """List of integers; presumably ids of other values (confirm)."""

    items: List[int]


@dataclass
class String:
    """String payload."""

    string_val: str


@dataclass
class SymInt:
    """Symbolic-integer payload, stored here as a plain int."""

    value: int


# Every payload type a graph value may hold. The member order is kept exactly
# as-is in case reflection-based serialization code depends on it.
GraphTypes = Union[
    Null,
    Int,
    Double,
    Bool,
    VkTensor,
    IntList,
    BoolList,
    DoubleList,
    ValueList,
    String,
    SymInt,
]
@dataclass
class VkValue:
    """One entry in the graph's value table, wrapping a single payload.

    The annotation is the string "GraphTypes", a forward reference to the
    payload union defined earlier in this module.
    """

    value: "GraphTypes"


@dataclass
class VkBytes:
    """A byte span described by a starting offset and a length.

    NOTE(review): presumably the span points into a data segment stored
    outside the serialized graph itself; confirm against the serializer.
    """

    offset: int
    length: int


@dataclass
class VkGraph:
    """Top-level serialized Vulkan graph.

    Field names and order mirror the schema referenced in the module
    docstring; the two override fields default to the DEFAULT_* sentinels.
    """

    version: str

    # Operator calls; NOTE(review): presumably in execution order - confirm.
    chain: List[OperatorCall]
    # Value table. NOTE(review): integer ids elsewhere in the graph (e.g.
    # input_ids / output_ids, OperatorCall.args) presumably index into this.
    values: List[VkValue]

    input_ids: List[int]
    output_ids: List[int]

    # Byte spans; judging by the names, presumably constant data and shader
    # code respectively - confirm against the serializer.
    constants: List[VkBytes]
    shaders: List[VkBytes]

    # Graph-wide storage/layout overrides. The DEFAULT_* sentinels presumably
    # mean "no override" - confirm against the runtime.
    storage_type_override: VkStorageType = VkStorageType.DEFAULT_STORAGE
    memory_layout_override: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT