xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This pass converts a TFLite uint8 graph to the int8 domain, with adaptors at
// input and output tensors. This is needed because TOSA precision is
// implemented in the int8 domain. This pass does:
// 1. match TFL::QConst with uint8, generate TFL::QConst with int8 with value
//    remapped.
// 2. insert tosa.RESCALE uint8 -> int8 if block argument (placeholder of
//    graph) is uint8 typed.
// 3. insert tosa.RESCALE int8 -> uint8 if original returned tensor is uint8
//    typed.
25 
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31 
32 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
33 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
34 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
35 #include "mlir/IR/Builders.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
37 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
38 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
39 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
44 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
45 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
46 
47 #define PASS_NAME "tosa-convert-tfl-uint8"
48 #define DEBUG_TYPE PASS_NAME
49 
50 namespace mlir {
51 namespace tosa {
52 namespace {
53 
54 // Performs lowering to TOSA dialect.
55 class ConvertUint8ToInt8
56     : public TosaConvertTFLUint8PassBase<ConvertUint8ToInt8> {
57  public:
ConvertUint8ToInt8()58   explicit ConvertUint8ToInt8() {}
59   void runOnOperation() override;
60 };
61 
62 struct ConvertUint8QConstOp : public RewritePattern {
ConvertUint8QConstOpmlir::tosa::__anon97a43a8d0111::ConvertUint8QConstOp63   explicit ConvertUint8QConstOp(MLIRContext *context)
64       : RewritePattern(TFL::QConstOp::getOperationName(), 1, context) {}
65 
matchAndRewritemlir::tosa::__anon97a43a8d0111::ConvertUint8QConstOp66   LogicalResult matchAndRewrite(Operation *op,
67                                 PatternRewriter &builder) const override {
68     auto tfl_qconst_op = cast<TFL::QConstOp>(op);
69 
70     // Skip if it's not ranked tensor type.
71     auto output_type =
72         tfl_qconst_op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
73     if (!output_type)
74       return builder.notifyMatchFailure(op, "not ranked tensor");
75 
76     // Skip if output is not per-tensor quantized type.
77     auto output_element_type =
78         output_type.getElementType()
79             .dyn_cast<mlir::quant::UniformQuantizedType>();
80     if (!output_element_type) return failure();
81 
82     // Skip if output is not uint8.
83     if (output_element_type.isSigned() ||
84         output_element_type.getStorageTypeIntegralWidth() != 8) {
85       return failure();
86     }
87 
88     mlir::DenseElementsAttr src_dense_attr =
89         tfl_qconst_op.value().cast<DenseElementsAttr>();
90 
91     double type_range_min =
92         static_cast<double>(output_element_type.getStorageTypeMin() -
93                             output_element_type.getZeroPoint()) *
94         output_element_type.getScale();
95     double type_range_max =
96         static_cast<double>(output_element_type.getStorageTypeMax() -
97                             output_element_type.getZeroPoint()) *
98         output_element_type.getScale();
99     bool narrow_range =
100         output_element_type.getStorageTypeMin() == 1 ? true : false;
101 
102     auto dst_qconst_type = TypeAttr::get(RankedTensorType::get(
103         output_type.getShape(),
104         buildQTypeFromMinMax(
105             builder, output_element_type.getExpressedType(),
106             builder.getF64FloatAttr(type_range_min),
107             builder.getF64FloatAttr(type_range_max),
108             builder.getI32IntegerAttr(
109                 output_element_type.getStorageTypeIntegralWidth()),
110             0, true /* signed */, builder.getBoolAttr(narrow_range))));
111 
112     Type dst_dense_element_type = builder.getIntegerType(8);
113     llvm::function_ref<APInt(const APInt &)> mapping =
114         [](const APInt &in) -> APInt {
115       int64_t in_i64 = in.getLimitedValue();
116       int64_t out_i64 = in_i64 - 128;
117       return APInt(8, out_i64, true);
118     };
119 
120     auto dst_dense_attr =
121         src_dense_attr.mapValues(dst_dense_element_type, mapping);
122 
123     builder.replaceOpWithNewOp<TFL::QConstOp>(op, dst_qconst_type,
124                                               dst_dense_attr);
125 
126     return success();
127   }
128 };
129 
convert_graph_uint8_tensor(mlir::MLIRContext & context,mlir::func::FuncOp & function)130 LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context,
131                                          mlir::func::FuncOp &function) {
132   size_t num_blocks_in_main = 0;
133   mlir::Region *region = function.getCallableRegion();
134   OpBuilder builder(&context);
135 
136   auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8));
137   auto tmp_const_attr =
138       DenseElementsAttr::get(tmp_const_type, {static_cast<uint8_t>(0)});
139 
140   for (mlir::Block &bb : region->getBlocks()) {
141     // Always have one block for each region right now.
142     num_blocks_in_main++;
143     if (num_blocks_in_main > 1) {
144       return function.emitError("Invalid MLIR: multiple blocks in a region");
145     }
146 
147     if (!bb.isEntryBlock()) {
148       return function.emitError("Invalid MLIR: block must be entry block");
149     }
150 
151     // Insert rescale uint8->int8 after placeholders.
152     for (Value arg : bb.getArguments()) {
153       auto uint8_type = arg.getType().dyn_cast<mlir::ShapedType>();
154       if (!uint8_type) continue;
155 
156       auto uint8_element_type =
157           uint8_type.getElementType()
158               .dyn_cast<mlir::quant::UniformQuantizedType>();
159       if (!uint8_element_type) continue;
160 
161       if (uint8_element_type.isSigned() ||
162           uint8_element_type.getStorageTypeIntegralWidth() != 8)
163         continue;
164 
165       double type_range_min =
166           static_cast<double>(uint8_element_type.getStorageTypeMin() -
167                               uint8_element_type.getZeroPoint()) *
168           uint8_element_type.getScale();
169       double type_range_max =
170           static_cast<double>(uint8_element_type.getStorageTypeMax() -
171                               uint8_element_type.getZeroPoint()) *
172           uint8_element_type.getScale();
173       bool narrow_range =
174           uint8_element_type.getStorageTypeMin() == 1 ? true : false;
175 
176       Type int8_type = uint8_type.clone(buildQTypeFromMinMax(
177           builder, uint8_element_type.getExpressedType(),
178           builder.getF64FloatAttr(type_range_min),
179           builder.getF64FloatAttr(type_range_max),
180           builder.getI32IntegerAttr(
181               uint8_element_type.getStorageTypeIntegralWidth()),
182           0, true /* signed */, builder.getBoolAttr(narrow_range)));
183 
184       int32_t uint8_zp = uint8_element_type.getZeroPoint();
185       int32_t int8_zp = uint8_zp - 128;
186 
187       // Keep original input_val use with tmp_val.
188       Value tmp_val = builder.create<TFL::ConstOp>(
189           function.getLoc(), tmp_const_type, tmp_const_attr);
190       arg.replaceAllUsesWith(tmp_val);
191       auto rescale_op = builder.create<tosa::RescaleOp>(
192           function.getLoc(), int8_type, arg,
193           builder.getI32IntegerAttr(uint8_zp),
194           builder.getI32IntegerAttr(int8_zp),
195           builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
196           builder.getBoolAttr(true), builder.getBoolAttr(false),
197           builder.getBoolAttr(false));
198 
199       Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
200       bb.push_front(op_rescale_op);
201       tmp_val.replaceAllUsesWith(rescale_op.getResult());
202       tmp_val.getDefiningOp()->erase();
203     }
204 
205     // Record types of original graph output before we convert intermediate
206     // tensor.
207     auto terminator = bb.getTerminator();
208     SmallVector<Type, 4> output_types;
209     for (Value val : terminator->getOperands()) {
210       output_types.push_back(val.getType());
211     }
212 
213     // Convert intermediate tensor.
214     for (auto &op : bb) {
215       for (Value output_val : op.getResults()) {
216         // Skip if output value is not RankedTensorType.
217         auto output_type = output_val.getType().dyn_cast<mlir::ShapedType>();
218         if (!output_type) continue;
219 
220         // Skip if output value is not per-tensor quantized element type.
221         auto output_element_type =
222             output_type.getElementType()
223                 .dyn_cast<mlir::quant::UniformQuantizedType>();
224         if (!output_element_type) continue;
225 
226         // Skip if output is not uint8.
227         if (output_element_type.isSigned() ||
228             output_element_type.getStorageTypeIntegralWidth() != 8)
229           continue;
230 
231         double type_range_min =
232             static_cast<double>(output_element_type.getStorageTypeMin() -
233                                 output_element_type.getZeroPoint()) *
234             output_element_type.getScale();
235         double type_range_max =
236             static_cast<double>(output_element_type.getStorageTypeMax() -
237                                 output_element_type.getZeroPoint()) *
238             output_element_type.getScale();
239         bool narrow_range =
240             output_element_type.getStorageTypeMin() == 1 ? true : false;
241 
242         Type new_type = output_type.clone(buildQTypeFromMinMax(
243             builder, output_element_type.getExpressedType(),
244             builder.getF64FloatAttr(type_range_min),
245             builder.getF64FloatAttr(type_range_max),
246             builder.getI32IntegerAttr(
247                 output_element_type.getStorageTypeIntegralWidth()),
248             0, true /* signed */, builder.getBoolAttr(narrow_range)));
249 
250         output_val.setType(new_type);
251       }
252     }
253 
254     if (terminator->getNumOperands() != output_types.size()) {
255       return function.emitError(
256           "Terminator's operand mismatch with number of outputs in graph");
257     }
258 
259     // Insert int8->uint8 rescale before all terminator's operand.
260     for (int32_t i = 0; i < terminator->getNumOperands(); i++) {
261       auto defining_op = terminator->getOperand(i).getDefiningOp();
262       // skip if operand of terminator is block arg (nullptr in this case) or
263       // not
264       if (!defining_op) continue;
265       Value input_val = defining_op->getResult(0);
266 
267       // Check if graph output is uint8 type.
268       auto uint8_output_type = output_types[i].dyn_cast<mlir::ShapedType>();
269       if (!uint8_output_type) continue;
270 
271       auto uint8_output_element_type =
272           uint8_output_type.getElementType()
273               .dyn_cast<mlir::quant::UniformQuantizedType>();
274       if (!uint8_output_element_type) continue;
275 
276       if (uint8_output_element_type.isSigned() ||
277           uint8_output_element_type.getStorageTypeIntegralWidth() != 8)
278         continue;
279 
280       // Check if output coming into terminator is int8 type.
281       auto int8_output_type =
282           terminator->getOperand(i).getType().dyn_cast<mlir::ShapedType>();
283       if (!int8_output_type) continue;
284 
285       auto int8_output_element_type =
286           int8_output_type.getElementType()
287               .dyn_cast<mlir::quant::UniformQuantizedType>();
288       if (!int8_output_element_type) continue;
289 
290       if (!int8_output_element_type.isSigned() ||
291           int8_output_element_type.getStorageTypeIntegralWidth() != 8)
292         continue;
293 
294       int32_t int8_zp = int8_output_element_type.getZeroPoint();
295       int32_t uint8_zp = uint8_output_element_type.getZeroPoint();
296 
297       // Sanity check if uint8/int8's scale and zeropoint match.
298       if (((uint8_zp - int8_zp) != 128) ||
299           (int8_output_element_type.getScale() !=
300            uint8_output_element_type.getScale())) {
301         return terminator->emitError(
302             "convert_uint8_to_int8: scale mismatch at the output tensors");
303       }
304 
305       // Keep original input_val use with tmp_val.
306       Value tmp_val = builder.create<TFL::ConstOp>(
307           function.getLoc(), tmp_const_type, tmp_const_attr);
308       input_val.replaceAllUsesWith(tmp_val);
309       auto rescale_op = builder.create<tosa::RescaleOp>(
310           function.getLoc(), uint8_output_type, input_val,
311           builder.getI32IntegerAttr(int8_zp),
312           builder.getI32IntegerAttr(uint8_zp),
313           builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
314           builder.getBoolAttr(true), builder.getBoolAttr(false),
315           builder.getBoolAttr(false));
316 
317       Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
318       bb.push_back(op_rescale_op);
319       op_rescale_op->moveBefore(terminator);
320       tmp_val.replaceAllUsesWith(rescale_op.getResult());
321       tmp_val.getDefiningOp()->erase();
322     }
323   }
324 
325   return success();
326 }
327 
runOnOperation()328 void ConvertUint8ToInt8::runOnOperation() {
329   RewritePatternSet patterns(&getContext());
330   auto &ctx = getContext();
331   mlir::func::FuncOp func = getOperation();
332 
333   // Convert uint8 const tensor. const needs to be handled specifically.
334   patterns.add<ConvertUint8QConstOp>(&ctx);
335   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
336 
337   // Replace uint8 tensor in the graph and insert rescale as needed.
338   (void)convert_graph_uint8_tensor(ctx, func);
339 }
340 
341 }  // anonymous namespace
342 
createConvertTFLUint8Pass()343 std::unique_ptr<OperationPass<func::FuncOp>> createConvertTFLUint8Pass() {
344   return std::make_unique<ConvertUint8ToInt8>();
345 }
346 
347 }  // namespace tosa
348 
349 }  // namespace mlir
350