xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/utils/convert_type.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
17 
18 #include "mlir/IR/Builders.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "mlir/IR/Types.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 
28 namespace tflite {
29 
30 using xla::StatusOr;
31 
32 namespace errors = tensorflow::errors;
33 
ConvertTypeToTensorType(mlir::Type type)34 tflite::TensorType ConvertTypeToTensorType(mlir::Type type) {
35   if (type.isF16()) {
36     return tflite::TensorType_FLOAT16;
37   } else if (type.isF32()) {
38     return tflite::TensorType_FLOAT32;
39   } else if (type.isF64()) {
40     return tflite::TensorType_FLOAT64;
41   } else if (type.isa<mlir::TF::StringType>()) {
42     return tflite::TensorType_STRING;
43   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
44     if (complex_type.getElementType().isF32()) {
45       return tflite::TensorType_COMPLEX64;
46     } else if (complex_type.getElementType().isF64()) {
47       return tflite::TensorType_COMPLEX128;
48     }
49     llvm_unreachable("invalid complex Type in conversion");
50   } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
51     switch (itype.getWidth()) {
52       case 1:
53         return tflite::TensorType_BOOL;
54       case 8:
55         if (itype.isUnsigned())
56           return tflite::TensorType_UINT8;
57         else
58           return tflite::TensorType_INT8;
59       case 16:
60         return tflite::TensorType_INT16;
61       case 32:
62         return tflite::TensorType_INT32;
63       case 64:
64         if (itype.isUnsigned())
65           return tflite::TensorType_UINT64;
66         else
67           return tflite::TensorType_INT64;
68       default:
69         llvm_unreachable("invalid integer Type in conversion");
70     }
71   }
72   llvm_unreachable("invalid Type in conversion");
73 }
74 
ConvertElementType(tflite::TensorType type,mlir::Builder builder)75 mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
76   switch (type) {
77     case tflite::TensorType_FLOAT16:
78       return builder.getF16Type();
79     case tflite::TensorType_FLOAT32:
80       return builder.getF32Type();
81     case tflite::TensorType_FLOAT64:
82       return builder.getF64Type();
83     case tflite::TensorType_INT32:
84       return builder.getIntegerType(32);
85     case tflite::TensorType_UINT16:
86       return builder.getIntegerType(16, /*isSigned=*/false);
87     case tflite::TensorType_UINT32:
88       return builder.getIntegerType(32, /*isSigned=*/false);
89     case tflite::TensorType_UINT8:
90       return builder.getIntegerType(8, /*isSigned=*/false);
91     case tflite::TensorType_INT64:
92       return builder.getIntegerType(64);
93     case tflite::TensorType_STRING:
94       return mlir::TF::StringType::get(builder.getContext());
95     case tflite::TensorType_BOOL:
96       return builder.getI1Type();
97     case tflite::TensorType_INT16:
98       return builder.getIntegerType(16);
99     case tflite::TensorType_COMPLEX64:
100       return mlir::ComplexType::get(builder.getF32Type());
101     case tflite::TensorType_COMPLEX128:
102       return mlir::ComplexType::get(builder.getF64Type());
103     case tflite::TensorType_INT8:
104       return builder.getIntegerType(8);
105     case tflite::TensorType_UINT64:
106       return builder.getIntegerType(64, /*isSigned=*/false);
107     case tflite::TensorType_RESOURCE:
108       return mlir::TF::ResourceType::get(builder.getContext());
109     case tflite::TensorType_VARIANT:
110       return mlir::TF::VariantType::get(builder.getContext());
111   }
112 }
113 
TflTypeToTfType(tflite::TensorType type)114 tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
115   switch (type) {
116     case tflite::TensorType_BOOL:
117       return tensorflow::DT_BOOL;
118     case tflite::TensorType_COMPLEX64:
119       return tensorflow::DT_COMPLEX64;
120     case tflite::TensorType_COMPLEX128:
121       return tensorflow::DT_COMPLEX128;
122     case tflite::TensorType_FLOAT16:
123       return tensorflow::DT_HALF;
124     case tflite::TensorType_FLOAT32:
125       return tensorflow::DT_FLOAT;
126     case tflite::TensorType_FLOAT64:
127       return tensorflow::DT_DOUBLE;
128     case tflite::TensorType_INT8:
129       return tensorflow::DT_INT8;
130     case tflite::TensorType_INT16:
131       return tensorflow::DT_INT16;
132     case tflite::TensorType_INT32:
133       return tensorflow::DT_INT32;
134     case tflite::TensorType_UINT32:
135       return tensorflow::DT_UINT32;
136     case tflite::TensorType_INT64:
137       return tensorflow::DT_INT64;
138     case tflite::TensorType_STRING:
139       return tensorflow::DT_STRING;
140     case tflite::TensorType_UINT8:
141       return tensorflow::DT_UINT8;
142     case tflite::TensorType_UINT16:
143       return tensorflow::DT_UINT16;
144     case tflite::TensorType_UINT64:
145       return tensorflow::DT_UINT64;
146     case tflite::TensorType_RESOURCE:
147       return tensorflow::DT_RESOURCE;
148     case tflite::TensorType_VARIANT:
149       return tensorflow::DT_VARIANT;
150   }
151 }
152 
TfTypeToTflType(tensorflow::DataType type)153 StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
154   switch (type) {
155     case tensorflow::DT_BOOL:
156       return tflite::TensorType_BOOL;
157     case tensorflow::DT_COMPLEX64:
158       return tflite::TensorType_COMPLEX64;
159     case tensorflow::DT_COMPLEX128:
160       return tflite::TensorType_COMPLEX128;
161     case tensorflow::DT_HALF:
162       return tflite::TensorType_FLOAT16;
163     case tensorflow::DT_FLOAT:
164       return tflite::TensorType_FLOAT32;
165     case tensorflow::DT_DOUBLE:
166       return tflite::TensorType_FLOAT64;
167     case tensorflow::DT_INT8:
168       return tflite::TensorType_INT8;
169     case tensorflow::DT_INT16:
170       return tflite::TensorType_INT16;
171     case tensorflow::DT_INT32:
172       return tflite::TensorType_INT32;
173     case tensorflow::DT_UINT32:
174       return tflite::TensorType_UINT32;
175     case tensorflow::DT_INT64:
176       return tflite::TensorType_INT64;
177     case tensorflow::DT_UINT64:
178       return tflite::TensorType_UINT64;
179     case tensorflow::DT_STRING:
180       return tflite::TensorType_STRING;
181     case tensorflow::DT_UINT8:
182       return tflite::TensorType_UINT8;
183     case tensorflow::DT_RESOURCE:
184       return tflite::TensorType_RESOURCE;
185     case tensorflow::DT_VARIANT:
186       return tflite::TensorType_VARIANT;
187     default:
188       return errors::InvalidArgument("unsupported tensor data type", type);
189   }
190 }
191 
GetShapeStrippedType(mlir::TypeAttr type_attr)192 mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) {
193   auto type = type_attr.getValue();
194   auto shaped_type = type.dyn_cast<mlir::ShapedType>();
195   if (shaped_type) {
196     return shaped_type.getElementType();
197   } else {
198     return type;
199   }
200 }
201 
NotFromQuantOpOrSameQuantType(mlir::Value val,mlir::TypeAttr qtype_attr)202 bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
203   auto val_defn_op = val.getDefiningOp();
204   mlir::TFL::QuantizeOp q_op =
205       llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
206   if (!q_op) return true;
207 
208   // Ignore shape details - we're really only trying to
209   // check if quantization is the same.
210   auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
211   auto stripped_qtype = GetShapeStrippedType(qtype_attr);
212   return stripped_src_qtype == stripped_qtype;
213 }
214 
215 }  // namespace tflite
216