xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Enum and attribute definitions used by TensorFlow Lite (TFL) dialect ops.

18#ifndef TFL_OP_ENUMS
19#define TFL_OP_ENUMS
20
21include "mlir/IR/AttrTypeBase.td"
22include "mlir/IR/EnumAttr.td"
23include "mlir/IR/OpBase.td"
24include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
25
// A string attribute whose value must be one of the strings in `cases`.
// Adapted from TF_AnyStrAttrOf in
// tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td.
//
// The predicate first requires the attribute to be a StringAttr
// (StrAttr.predicate) and only then casts it and compares its value against
// each case. Because the generated conjunction short-circuits, a non-string
// attribute fails verification with this constraint's message instead of
// reaching an invalid cast. `cases` is expected to be non-empty (`!head` is
// applied to it).
class TFL_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
  And<[
    StrAttr.predicate,
    CPred<!foldl(
        "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
        !foreach(case, !tail(cases),
                 "$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
        prev, cur, prev # " || " # cur)>
  ]>,
  "string attribute whose value is " #
    !foldl(/*init*/!head(cases), /*list*/!tail(cases),
           prev, cur, prev # ", or " # cur)>;
37
// Dimension formats used when encoding sparse tensors. The TFLite schema
// explains these parameters in detail.
def TFL_DT_Dense     : I32EnumAttrCase<"DENSE", 0>;
def TFL_DT_SparseCSR : I32EnumAttrCase<"SPARSE_CSR", 1>;

// C++ enum ::mlir::TFL::DimensionType. This record does not generate a
// specialized attribute class (genSpecializedAttr = 0); the enum is wrapped
// as a dialect attribute by a separate EnumAttr definition.
def TFL_DimensionType
    : I32EnumAttr<"DimensionType", "dimension_type",
                  [TFL_DT_Dense, TFL_DT_SparseCSR]> {
  let cppNamespace = "::mlir::TFL";
  let genSpecializedAttr = 0;
}
48
// TFL dialect attribute wrapping the TFL_DimensionType enum, named
// "dimension_type_attr".
def TFL_DimensionTypeAttr
    : EnumAttr<TFL_Dialect, TFL_DimensionType, "dimension_type_attr"> {
  // Hand back the stored value as-is ($_self) instead of converting it.
  let convertFromStorage = "$_self";
}
53
// Activation function cases accepted by TFL ops.
// Keep these in sync with the ActivationFunctionType enum in the TFLite
// schema.
def TFL_AFEnum_None  : I32EnumAttrCase<"NONE",         0>;
def TFL_AFEnum_Relu  : I32EnumAttrCase<"RELU",         1>;
def TFL_AFEnum_Relu1 : I32EnumAttrCase<"RELU_N1_TO_1", 2>;
def TFL_AFEnum_Relu6 : I32EnumAttrCase<"RELU6",        3>;
def TFL_AFEnum_Tanh  : I32EnumAttrCase<"TANH",         4>;
def TFL_AFEnum_Sign  : I32EnumAttrCase<"SIGN_BIT",     5>;

// String attribute restricted to the activation function names above.
def TFL_AFAttr : TFL_AnyStrAttrOf<[
  TFL_AFEnum_None.symbol,
  TFL_AFEnum_Relu.symbol,
  TFL_AFEnum_Relu1.symbol,
  TFL_AFEnum_Relu6.symbol,
  TFL_AFEnum_Tanh.symbol,
  TFL_AFEnum_Sign.symbol
]>;
67
// Constant TFL_AFAttr values, one per activation function case; each holds
// the symbol string of the corresponding TFL_AFEnum_* case.
def TFL_AF_None :
    ConstantStrAttr<TFL_AFAttr, TFL_AFEnum_None.symbol>;
def TFL_AF_Relu :
    ConstantStrAttr<TFL_AFAttr, TFL_AFEnum_Relu.symbol>;
def TFL_AF_Relu1 :
    ConstantStrAttr<TFL_AFAttr, TFL_AFEnum_Relu1.symbol>;
def TFL_AF_Relu6 :
    ConstantStrAttr<TFL_AFAttr, TFL_AFEnum_Relu6.symbol>;
def TFL_AF_Tanh :
    ConstantStrAttr<TFL_AFAttr, TFL_AFEnum_Tanh.symbol>;
def TFL_AF_Sign :
    ConstantStrAttr<TFL_AFAttr, TFL_AFEnum_Sign.symbol>;
75
// Padding cases accepted by TFL ops.
// Keep these in sync with the padding enum in the TFLite schema.
def TFL_PADEnum_Same  : I32EnumAttrCase<"SAME",  0>;
def TFL_PADEnum_Valid : I32EnumAttrCase<"VALID", 1>;

// String attribute restricted to the padding names above, plus one constant
// attribute value per case.
def TFL_PaddingAttr :
    TFL_AnyStrAttrOf<[TFL_PADEnum_Same.symbol, TFL_PADEnum_Valid.symbol]>;
def TFL_PAD_Same :
    ConstantStrAttr<TFL_PaddingAttr, TFL_PADEnum_Same.symbol>;
def TFL_PAD_Valid :
    ConstantStrAttr<TFL_PaddingAttr, TFL_PADEnum_Valid.symbol>;
85
// Weight format cases for FullyConnectedOptionsWeightFormat.
def TFL_FCWOEnum_Default        : I32EnumAttrCase<"DEFAULT", 0>;
def TFL_FCWOEnum_Shuffled4x16i8 : I32EnumAttrCase<"SHUFFLED4x16INT8", 1>;

// String attribute restricted to the weight format names above.
def TFL_FullyConnectedOptionsWeightFormatAttr : TFL_AnyStrAttrOf<[
  TFL_FCWOEnum_Default.symbol,
  TFL_FCWOEnum_Shuffled4x16i8.symbol
]>;

// Constant attribute values, one per weight format case.
def TFL_FCWO_Default :
    ConstantStrAttr<TFL_FullyConnectedOptionsWeightFormatAttr,
                    TFL_FCWOEnum_Default.symbol>;
def TFL_FCWO_Shuffled4x16i8 :
    ConstantStrAttr<TFL_FullyConnectedOptionsWeightFormatAttr,
                    TFL_FCWOEnum_Shuffled4x16i8.symbol>;
98
// Mirror padding modes.
def TFL_MIRRORPAD_Reflect   : I32EnumAttrCase<"REFLECT", 0>;
def TFL_MIRRORPAD_Symmetric : I32EnumAttrCase<"SYMMETRIC", 1>;

// C++ enum ::mlir::TFL::MirrorPaddingType. No specialized attribute class is
// generated from this record (genSpecializedAttr = 0); TFL_MirrorPaddingAttr
// below provides the dialect attribute instead.
def TFL_MirrorPaddingType
    : I32EnumAttr<"MirrorPaddingType", "mirror_pad_enum",
                  [TFL_MIRRORPAD_Reflect, TFL_MIRRORPAD_Symmetric]> {
  let cppNamespace = "::mlir::TFL";
  let genSpecializedAttr = 0;
}

// TFL dialect attribute wrapping TFL_MirrorPaddingType.
def TFL_MirrorPaddingAttr
    : EnumAttr<TFL_Dialect, TFL_MirrorPaddingType, "mirror_pad_attr">;
110
// LSTM kernel types.
def TFL_LSTM_KT_FULL  : I32EnumAttrCase<"FULL", 0>;
def TFL_LSTM_KT_BASIC : I32EnumAttrCase<"BASIC", 1>;

// C++ enum ::mlir::TFL::LSTMKernelType. No specialized attribute class is
// generated from this record (genSpecializedAttr = 0); TFL_LSTMKernelTypeAttr
// below provides the dialect attribute instead.
def TFL_LSTMKernelType
    : I32EnumAttr<"LSTMKernelType", "lstm_kernel_type",
                  [TFL_LSTM_KT_FULL, TFL_LSTM_KT_BASIC]> {
  let cppNamespace = "::mlir::TFL";
  let genSpecializedAttr = 0;
}

// TFL dialect attribute wrapping TFL_LSTMKernelType.
def TFL_LSTMKernelTypeAttr
    : EnumAttr<TFL_Dialect, TFL_LSTMKernelType, "lstm_kernel_type_attr">;
123
// Attribute/type parameter holding a list of int32 values, exposed to C++ as
// ::llvm::ArrayRef<int32_t>. Parsing is delegated to
// ::mlir::TFL::parseI32Array, which is defined outside this file.
def I32ArrayParameter :
    AttrOrTypeParameter<"::llvm::ArrayRef<int32_t>", ""> {
  // Copy the incoming array into the storage allocator ($_allocator.copyInto)
  // when the parameter is stored.
  let allocator = [{$_dst = $_allocator.copyInto($_self);}];
  // C++ type used to hold the parameter inside the attribute storage.
  let cppStorageType = "::llvm::SmallVector<int32_t>";
  let parser = "::mlir::TFL::parseI32Array($_parser)";
  // Wraps the elements in square brackets; element separators come from the
  // printer's operator<< for the array (presumably comma-separated — confirm).
  let printer = "$_printer << '[' << $_self << ']'";
}
131
// Per-dimension metadata of a sparse tensor: the dimension's format
// (TFL_DimensionTypeAttr), its dense size, and its segment and index arrays.
// Uses the declarative `<` struct(params) `>` assembly format.
def DimensionMetadataAttr : AttrDef<TFL_Dialect, "DimensionMetadata"> {
  let summary = "Dimension metadata.";
  let mnemonic = "dimension_metadata";
  let assemblyFormat = "`<` struct(params) `>`";
  let parameters = (ins
    TFL_DimensionTypeAttr:$format,
    "int32_t":$dense_size,
    I32ArrayParameter:$segments,
    I32ArrayParameter:$indices
  );
}
143
// Sparsity parameters of a tensor: traversal order, block map, and one
// DimensionMetadataAttr per dimension. Uses the declarative
// `<` struct(params) `>` assembly format.
def SparsityParameterAttr : AttrDef<TFL_Dialect, "SparsityParameter"> {
  let summary = "Sparsity parameter.";
  let mnemonic = "sparsity_parameter";
  let assemblyFormat = "`<` struct(params) `>`";
  let parameters = (ins
    I32ArrayParameter:$traversal_order,
    I32ArrayParameter:$block_map,
    ArrayRefParameter<"DimensionMetadataAttr">:$dim_metadata
  );
}
154
// TFL dialect attribute carrying a single string parameter ($value).
// Its assembly format is hand-written in C++ (hasCustomAssemblyFormat = 1),
// outside this file.
def TFL_ConstBytesAttr : AttrDef<TFL_Dialect, "ConstBytes"> {
  let mnemonic = "const_bytes";
  let parameters = (ins StringRefParameter<"">:$value);
  let hasCustomAssemblyFormat = 1;

  let summary = "A string attribute representation of compiled bytes";
  let description = [{
    Syntax Examples:

    ```mlir
    #tfl<const_bytes : "0xDEADBEEF">
    ```
  }];
}
168
169#endif // TFL_OP_ENUMS
170