1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
17
18 #include "mlir/IR/Builders.h" // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
20 #include "mlir/IR/Types.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27
28 namespace tflite {
29
30 using xla::StatusOr;
31
32 namespace errors = tensorflow::errors;
33
ConvertTypeToTensorType(mlir::Type type)34 tflite::TensorType ConvertTypeToTensorType(mlir::Type type) {
35 if (type.isF16()) {
36 return tflite::TensorType_FLOAT16;
37 } else if (type.isF32()) {
38 return tflite::TensorType_FLOAT32;
39 } else if (type.isF64()) {
40 return tflite::TensorType_FLOAT64;
41 } else if (type.isa<mlir::TF::StringType>()) {
42 return tflite::TensorType_STRING;
43 } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
44 if (complex_type.getElementType().isF32()) {
45 return tflite::TensorType_COMPLEX64;
46 } else if (complex_type.getElementType().isF64()) {
47 return tflite::TensorType_COMPLEX128;
48 }
49 llvm_unreachable("invalid complex Type in conversion");
50 } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
51 switch (itype.getWidth()) {
52 case 1:
53 return tflite::TensorType_BOOL;
54 case 8:
55 if (itype.isUnsigned())
56 return tflite::TensorType_UINT8;
57 else
58 return tflite::TensorType_INT8;
59 case 16:
60 return tflite::TensorType_INT16;
61 case 32:
62 return tflite::TensorType_INT32;
63 case 64:
64 if (itype.isUnsigned())
65 return tflite::TensorType_UINT64;
66 else
67 return tflite::TensorType_INT64;
68 default:
69 llvm_unreachable("invalid integer Type in conversion");
70 }
71 }
72 llvm_unreachable("invalid Type in conversion");
73 }
74
ConvertElementType(tflite::TensorType type,mlir::Builder builder)75 mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
76 switch (type) {
77 case tflite::TensorType_FLOAT16:
78 return builder.getF16Type();
79 case tflite::TensorType_FLOAT32:
80 return builder.getF32Type();
81 case tflite::TensorType_FLOAT64:
82 return builder.getF64Type();
83 case tflite::TensorType_INT32:
84 return builder.getIntegerType(32);
85 case tflite::TensorType_UINT16:
86 return builder.getIntegerType(16, /*isSigned=*/false);
87 case tflite::TensorType_UINT32:
88 return builder.getIntegerType(32, /*isSigned=*/false);
89 case tflite::TensorType_UINT8:
90 return builder.getIntegerType(8, /*isSigned=*/false);
91 case tflite::TensorType_INT64:
92 return builder.getIntegerType(64);
93 case tflite::TensorType_STRING:
94 return mlir::TF::StringType::get(builder.getContext());
95 case tflite::TensorType_BOOL:
96 return builder.getI1Type();
97 case tflite::TensorType_INT16:
98 return builder.getIntegerType(16);
99 case tflite::TensorType_COMPLEX64:
100 return mlir::ComplexType::get(builder.getF32Type());
101 case tflite::TensorType_COMPLEX128:
102 return mlir::ComplexType::get(builder.getF64Type());
103 case tflite::TensorType_INT8:
104 return builder.getIntegerType(8);
105 case tflite::TensorType_UINT64:
106 return builder.getIntegerType(64, /*isSigned=*/false);
107 case tflite::TensorType_RESOURCE:
108 return mlir::TF::ResourceType::get(builder.getContext());
109 case tflite::TensorType_VARIANT:
110 return mlir::TF::VariantType::get(builder.getContext());
111 }
112 }
113
TflTypeToTfType(tflite::TensorType type)114 tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
115 switch (type) {
116 case tflite::TensorType_BOOL:
117 return tensorflow::DT_BOOL;
118 case tflite::TensorType_COMPLEX64:
119 return tensorflow::DT_COMPLEX64;
120 case tflite::TensorType_COMPLEX128:
121 return tensorflow::DT_COMPLEX128;
122 case tflite::TensorType_FLOAT16:
123 return tensorflow::DT_HALF;
124 case tflite::TensorType_FLOAT32:
125 return tensorflow::DT_FLOAT;
126 case tflite::TensorType_FLOAT64:
127 return tensorflow::DT_DOUBLE;
128 case tflite::TensorType_INT8:
129 return tensorflow::DT_INT8;
130 case tflite::TensorType_INT16:
131 return tensorflow::DT_INT16;
132 case tflite::TensorType_INT32:
133 return tensorflow::DT_INT32;
134 case tflite::TensorType_UINT32:
135 return tensorflow::DT_UINT32;
136 case tflite::TensorType_INT64:
137 return tensorflow::DT_INT64;
138 case tflite::TensorType_STRING:
139 return tensorflow::DT_STRING;
140 case tflite::TensorType_UINT8:
141 return tensorflow::DT_UINT8;
142 case tflite::TensorType_UINT16:
143 return tensorflow::DT_UINT16;
144 case tflite::TensorType_UINT64:
145 return tensorflow::DT_UINT64;
146 case tflite::TensorType_RESOURCE:
147 return tensorflow::DT_RESOURCE;
148 case tflite::TensorType_VARIANT:
149 return tensorflow::DT_VARIANT;
150 }
151 }
152
TfTypeToTflType(tensorflow::DataType type)153 StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
154 switch (type) {
155 case tensorflow::DT_BOOL:
156 return tflite::TensorType_BOOL;
157 case tensorflow::DT_COMPLEX64:
158 return tflite::TensorType_COMPLEX64;
159 case tensorflow::DT_COMPLEX128:
160 return tflite::TensorType_COMPLEX128;
161 case tensorflow::DT_HALF:
162 return tflite::TensorType_FLOAT16;
163 case tensorflow::DT_FLOAT:
164 return tflite::TensorType_FLOAT32;
165 case tensorflow::DT_DOUBLE:
166 return tflite::TensorType_FLOAT64;
167 case tensorflow::DT_INT8:
168 return tflite::TensorType_INT8;
169 case tensorflow::DT_INT16:
170 return tflite::TensorType_INT16;
171 case tensorflow::DT_INT32:
172 return tflite::TensorType_INT32;
173 case tensorflow::DT_UINT32:
174 return tflite::TensorType_UINT32;
175 case tensorflow::DT_INT64:
176 return tflite::TensorType_INT64;
177 case tensorflow::DT_UINT64:
178 return tflite::TensorType_UINT64;
179 case tensorflow::DT_STRING:
180 return tflite::TensorType_STRING;
181 case tensorflow::DT_UINT8:
182 return tflite::TensorType_UINT8;
183 case tensorflow::DT_RESOURCE:
184 return tflite::TensorType_RESOURCE;
185 case tensorflow::DT_VARIANT:
186 return tflite::TensorType_VARIANT;
187 default:
188 return errors::InvalidArgument("unsupported tensor data type", type);
189 }
190 }
191
GetShapeStrippedType(mlir::TypeAttr type_attr)192 mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) {
193 auto type = type_attr.getValue();
194 auto shaped_type = type.dyn_cast<mlir::ShapedType>();
195 if (shaped_type) {
196 return shaped_type.getElementType();
197 } else {
198 return type;
199 }
200 }
201
NotFromQuantOpOrSameQuantType(mlir::Value val,mlir::TypeAttr qtype_attr)202 bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
203 auto val_defn_op = val.getDefiningOp();
204 mlir::TFL::QuantizeOp q_op =
205 llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
206 if (!q_op) return true;
207
208 // Ignore shape details - we're really only trying to
209 // check if quantization is the same.
210 auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
211 auto stripped_qtype = GetShapeStrippedType(qtype_attr);
212 return stripped_src_qtype == stripped_qtype;
213 }
214
215 } // namespace tflite
216