1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass converts a TFLite uint8 graph to the int8 domain, with adaptors at
17 // input and output tensors. This is needed because TOSA precision is
18 // implemented in the int8 domain. This pass does:
19 // 1. match TFL::QConst with uint8, generate TFL::QConst with int8 with value
20 // remapped.
21 // 2. insert tosa.RESCALE uint8 -> int8 if block argument (placeholder of graph)
22 // is uint8 typed.
23 // 3. insert tosa.RESCALE int8 -> uint8 if original returned tensor is uint8
24 // typed.
25
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31
32 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
33 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
34 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
35 #include "mlir/IR/Builders.h" // from @llvm-project
36 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
37 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
38 #include "mlir/IR/PatternMatch.h" // from @llvm-project
39 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
44 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
45 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
46
47 #define PASS_NAME "tosa-convert-tfl-uint8"
48 #define DEBUG_TYPE PASS_NAME
49
50 namespace mlir {
51 namespace tosa {
52 namespace {
53
54 // Performs lowering to TOSA dialect.
55 class ConvertUint8ToInt8
56 : public TosaConvertTFLUint8PassBase<ConvertUint8ToInt8> {
57 public:
ConvertUint8ToInt8()58 explicit ConvertUint8ToInt8() {}
59 void runOnOperation() override;
60 };
61
62 struct ConvertUint8QConstOp : public RewritePattern {
ConvertUint8QConstOpmlir::tosa::__anon97a43a8d0111::ConvertUint8QConstOp63 explicit ConvertUint8QConstOp(MLIRContext *context)
64 : RewritePattern(TFL::QConstOp::getOperationName(), 1, context) {}
65
matchAndRewritemlir::tosa::__anon97a43a8d0111::ConvertUint8QConstOp66 LogicalResult matchAndRewrite(Operation *op,
67 PatternRewriter &builder) const override {
68 auto tfl_qconst_op = cast<TFL::QConstOp>(op);
69
70 // Skip if it's not ranked tensor type.
71 auto output_type =
72 tfl_qconst_op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
73 if (!output_type)
74 return builder.notifyMatchFailure(op, "not ranked tensor");
75
76 // Skip if output is not per-tensor quantized type.
77 auto output_element_type =
78 output_type.getElementType()
79 .dyn_cast<mlir::quant::UniformQuantizedType>();
80 if (!output_element_type) return failure();
81
82 // Skip if output is not uint8.
83 if (output_element_type.isSigned() ||
84 output_element_type.getStorageTypeIntegralWidth() != 8) {
85 return failure();
86 }
87
88 mlir::DenseElementsAttr src_dense_attr =
89 tfl_qconst_op.value().cast<DenseElementsAttr>();
90
91 double type_range_min =
92 static_cast<double>(output_element_type.getStorageTypeMin() -
93 output_element_type.getZeroPoint()) *
94 output_element_type.getScale();
95 double type_range_max =
96 static_cast<double>(output_element_type.getStorageTypeMax() -
97 output_element_type.getZeroPoint()) *
98 output_element_type.getScale();
99 bool narrow_range =
100 output_element_type.getStorageTypeMin() == 1 ? true : false;
101
102 auto dst_qconst_type = TypeAttr::get(RankedTensorType::get(
103 output_type.getShape(),
104 buildQTypeFromMinMax(
105 builder, output_element_type.getExpressedType(),
106 builder.getF64FloatAttr(type_range_min),
107 builder.getF64FloatAttr(type_range_max),
108 builder.getI32IntegerAttr(
109 output_element_type.getStorageTypeIntegralWidth()),
110 0, true /* signed */, builder.getBoolAttr(narrow_range))));
111
112 Type dst_dense_element_type = builder.getIntegerType(8);
113 llvm::function_ref<APInt(const APInt &)> mapping =
114 [](const APInt &in) -> APInt {
115 int64_t in_i64 = in.getLimitedValue();
116 int64_t out_i64 = in_i64 - 128;
117 return APInt(8, out_i64, true);
118 };
119
120 auto dst_dense_attr =
121 src_dense_attr.mapValues(dst_dense_element_type, mapping);
122
123 builder.replaceOpWithNewOp<TFL::QConstOp>(op, dst_qconst_type,
124 dst_dense_attr);
125
126 return success();
127 }
128 };
129
convert_graph_uint8_tensor(mlir::MLIRContext & context,mlir::func::FuncOp & function)130 LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context,
131 mlir::func::FuncOp &function) {
132 size_t num_blocks_in_main = 0;
133 mlir::Region *region = function.getCallableRegion();
134 OpBuilder builder(&context);
135
136 auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8));
137 auto tmp_const_attr =
138 DenseElementsAttr::get(tmp_const_type, {static_cast<uint8_t>(0)});
139
140 for (mlir::Block &bb : region->getBlocks()) {
141 // Always have one block for each region right now.
142 num_blocks_in_main++;
143 if (num_blocks_in_main > 1) {
144 return function.emitError("Invalid MLIR: multiple blocks in a region");
145 }
146
147 if (!bb.isEntryBlock()) {
148 return function.emitError("Invalid MLIR: block must be entry block");
149 }
150
151 // Insert rescale uint8->int8 after placeholders.
152 for (Value arg : bb.getArguments()) {
153 auto uint8_type = arg.getType().dyn_cast<mlir::ShapedType>();
154 if (!uint8_type) continue;
155
156 auto uint8_element_type =
157 uint8_type.getElementType()
158 .dyn_cast<mlir::quant::UniformQuantizedType>();
159 if (!uint8_element_type) continue;
160
161 if (uint8_element_type.isSigned() ||
162 uint8_element_type.getStorageTypeIntegralWidth() != 8)
163 continue;
164
165 double type_range_min =
166 static_cast<double>(uint8_element_type.getStorageTypeMin() -
167 uint8_element_type.getZeroPoint()) *
168 uint8_element_type.getScale();
169 double type_range_max =
170 static_cast<double>(uint8_element_type.getStorageTypeMax() -
171 uint8_element_type.getZeroPoint()) *
172 uint8_element_type.getScale();
173 bool narrow_range =
174 uint8_element_type.getStorageTypeMin() == 1 ? true : false;
175
176 Type int8_type = uint8_type.clone(buildQTypeFromMinMax(
177 builder, uint8_element_type.getExpressedType(),
178 builder.getF64FloatAttr(type_range_min),
179 builder.getF64FloatAttr(type_range_max),
180 builder.getI32IntegerAttr(
181 uint8_element_type.getStorageTypeIntegralWidth()),
182 0, true /* signed */, builder.getBoolAttr(narrow_range)));
183
184 int32_t uint8_zp = uint8_element_type.getZeroPoint();
185 int32_t int8_zp = uint8_zp - 128;
186
187 // Keep original input_val use with tmp_val.
188 Value tmp_val = builder.create<TFL::ConstOp>(
189 function.getLoc(), tmp_const_type, tmp_const_attr);
190 arg.replaceAllUsesWith(tmp_val);
191 auto rescale_op = builder.create<tosa::RescaleOp>(
192 function.getLoc(), int8_type, arg,
193 builder.getI32IntegerAttr(uint8_zp),
194 builder.getI32IntegerAttr(int8_zp),
195 builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
196 builder.getBoolAttr(true), builder.getBoolAttr(false),
197 builder.getBoolAttr(false));
198
199 Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
200 bb.push_front(op_rescale_op);
201 tmp_val.replaceAllUsesWith(rescale_op.getResult());
202 tmp_val.getDefiningOp()->erase();
203 }
204
205 // Record types of original graph output before we convert intermediate
206 // tensor.
207 auto terminator = bb.getTerminator();
208 SmallVector<Type, 4> output_types;
209 for (Value val : terminator->getOperands()) {
210 output_types.push_back(val.getType());
211 }
212
213 // Convert intermediate tensor.
214 for (auto &op : bb) {
215 for (Value output_val : op.getResults()) {
216 // Skip if output value is not RankedTensorType.
217 auto output_type = output_val.getType().dyn_cast<mlir::ShapedType>();
218 if (!output_type) continue;
219
220 // Skip if output value is not per-tensor quantized element type.
221 auto output_element_type =
222 output_type.getElementType()
223 .dyn_cast<mlir::quant::UniformQuantizedType>();
224 if (!output_element_type) continue;
225
226 // Skip if output is not uint8.
227 if (output_element_type.isSigned() ||
228 output_element_type.getStorageTypeIntegralWidth() != 8)
229 continue;
230
231 double type_range_min =
232 static_cast<double>(output_element_type.getStorageTypeMin() -
233 output_element_type.getZeroPoint()) *
234 output_element_type.getScale();
235 double type_range_max =
236 static_cast<double>(output_element_type.getStorageTypeMax() -
237 output_element_type.getZeroPoint()) *
238 output_element_type.getScale();
239 bool narrow_range =
240 output_element_type.getStorageTypeMin() == 1 ? true : false;
241
242 Type new_type = output_type.clone(buildQTypeFromMinMax(
243 builder, output_element_type.getExpressedType(),
244 builder.getF64FloatAttr(type_range_min),
245 builder.getF64FloatAttr(type_range_max),
246 builder.getI32IntegerAttr(
247 output_element_type.getStorageTypeIntegralWidth()),
248 0, true /* signed */, builder.getBoolAttr(narrow_range)));
249
250 output_val.setType(new_type);
251 }
252 }
253
254 if (terminator->getNumOperands() != output_types.size()) {
255 return function.emitError(
256 "Terminator's operand mismatch with number of outputs in graph");
257 }
258
259 // Insert int8->uint8 rescale before all terminator's operand.
260 for (int32_t i = 0; i < terminator->getNumOperands(); i++) {
261 auto defining_op = terminator->getOperand(i).getDefiningOp();
262 // skip if operand of terminator is block arg (nullptr in this case) or
263 // not
264 if (!defining_op) continue;
265 Value input_val = defining_op->getResult(0);
266
267 // Check if graph output is uint8 type.
268 auto uint8_output_type = output_types[i].dyn_cast<mlir::ShapedType>();
269 if (!uint8_output_type) continue;
270
271 auto uint8_output_element_type =
272 uint8_output_type.getElementType()
273 .dyn_cast<mlir::quant::UniformQuantizedType>();
274 if (!uint8_output_element_type) continue;
275
276 if (uint8_output_element_type.isSigned() ||
277 uint8_output_element_type.getStorageTypeIntegralWidth() != 8)
278 continue;
279
280 // Check if output coming into terminator is int8 type.
281 auto int8_output_type =
282 terminator->getOperand(i).getType().dyn_cast<mlir::ShapedType>();
283 if (!int8_output_type) continue;
284
285 auto int8_output_element_type =
286 int8_output_type.getElementType()
287 .dyn_cast<mlir::quant::UniformQuantizedType>();
288 if (!int8_output_element_type) continue;
289
290 if (!int8_output_element_type.isSigned() ||
291 int8_output_element_type.getStorageTypeIntegralWidth() != 8)
292 continue;
293
294 int32_t int8_zp = int8_output_element_type.getZeroPoint();
295 int32_t uint8_zp = uint8_output_element_type.getZeroPoint();
296
297 // Sanity check if uint8/int8's scale and zeropoint match.
298 if (((uint8_zp - int8_zp) != 128) ||
299 (int8_output_element_type.getScale() !=
300 uint8_output_element_type.getScale())) {
301 return terminator->emitError(
302 "convert_uint8_to_int8: scale mismatch at the output tensors");
303 }
304
305 // Keep original input_val use with tmp_val.
306 Value tmp_val = builder.create<TFL::ConstOp>(
307 function.getLoc(), tmp_const_type, tmp_const_attr);
308 input_val.replaceAllUsesWith(tmp_val);
309 auto rescale_op = builder.create<tosa::RescaleOp>(
310 function.getLoc(), uint8_output_type, input_val,
311 builder.getI32IntegerAttr(int8_zp),
312 builder.getI32IntegerAttr(uint8_zp),
313 builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
314 builder.getBoolAttr(true), builder.getBoolAttr(false),
315 builder.getBoolAttr(false));
316
317 Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
318 bb.push_back(op_rescale_op);
319 op_rescale_op->moveBefore(terminator);
320 tmp_val.replaceAllUsesWith(rescale_op.getResult());
321 tmp_val.getDefiningOp()->erase();
322 }
323 }
324
325 return success();
326 }
327
runOnOperation()328 void ConvertUint8ToInt8::runOnOperation() {
329 RewritePatternSet patterns(&getContext());
330 auto &ctx = getContext();
331 mlir::func::FuncOp func = getOperation();
332
333 // Convert uint8 const tensor. const needs to be handled specifically.
334 patterns.add<ConvertUint8QConstOp>(&ctx);
335 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
336
337 // Replace uint8 tensor in the graph and insert rescale as needed.
338 (void)convert_graph_uint8_tensor(ctx, func);
339 }
340
341 } // anonymous namespace
342
createConvertTFLUint8Pass()343 std::unique_ptr<OperationPass<func::FuncOp>> createConvertTFLUint8Pass() {
344 return std::make_unique<ConvertUint8ToInt8>();
345 }
346
347 } // namespace tosa
348
349 } // namespace mlir
350