/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"
#include "tensorflow/lite/tools/optimize/quantization_utils.h"

namespace tflite {
namespace optimize {

namespace {

struct ConsumerOpInfo {
  OperatorT* op;
  // The index of the op in the operators vector.
  int32_t op_idx;
  // The index of the tensor to quantize in subgraph->tensors.
  int32_t op_input_idx;
};

struct TensorPerChannel {
  TensorT* t;
  bool is_per_channel;
  int channel_dim;
};

// The default minimum number of elements a weights array must have to be
// quantized by this transformation.
const int kWeightsMinNumElementsDefault = 1024;

// Constructs the MLIR CustomOpMap from the TFLite CustomOpMap, as their member
// variables differ.
void ConstructMLIRCustomOpMap(mlir::lite::CustomOpMap& mlir_map,
                              const CustomOpMap& tflite_map) {
  for (const auto& entry : tflite_map) {
    mlir_map[entry.first].quantizable_input_indices =
        entry.second.quantizable_input_indices;
    mlir_map[entry.first].is_weight_only = !entry.second.is_hybrid;
    mlir_map[entry.first].no_side_effect = true;
  }
}

// Gets the operators that consume tensor_idx.
std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
                                               const SubGraphT* subgraph,
                                               int32_t tensor_idx) {
  // TODO(suharshs): If this proves to be too slow, avoid calling it per tensor,
  // instead doing one sweep for the entire model.
  std::vector<ConsumerOpInfo> consumer_ops;
  for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
    OperatorT* op = subgraph->operators[op_idx].get();
    if (op == nullptr) {
      continue;
    }
    for (size_t i = 0; i < op->inputs.size(); ++i) {
      if (op->inputs[i] == tensor_idx) {
        consumer_ops.push_back(
            {op, static_cast<int32_t>(op_idx), static_cast<int32_t>(i)});
      }
    }
  }
  return consumer_ops;
}

// Gets the list of op->inputs indices of the weights inputs to be quantized
// for the provided op.
std::vector<int32_t> GetWeightInputIndices(const OperatorCodeT* op_code,
                                           const CustomOpMap& custom_op_map) {
  const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code);
  if (builtin_op_code == BuiltinOperator_CUSTOM) {
    const std::string custom_code = op_code->custom_code;
    const auto& custom_op_info = custom_op_map.find(custom_code);
    if (custom_op_info != custom_op_map.end()) {
      return custom_op_info->second.quantizable_input_indices;
    }
  } else if (builtin_op_code == BuiltinOperator_CONV_2D ||
             builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
             builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
             builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
             builtin_op_code == BuiltinOperator_TRANSPOSE_CONV) {
    return {1};
  } else if (builtin_op_code == BuiltinOperator_SVDF) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/svdf.cc
    return {1, 2};
  } else if (builtin_op_code == BuiltinOperator_LSTM ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/lstm.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc
    return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16};
  } else if (builtin_op_code == BuiltinOperator_RNN ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/basic_rnn.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc
    return {1, 2};
  } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc
    return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 18, 19, 20, 21,
            22, 23, 24, 25, 26, 27, 28, 33, 40, 41, 42, 43, 44, 45, 46, 47};
  } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
    return {1, 2, 4, 5, 6, 8, 9, 10, 11};
  } else if (builtin_op_code == BuiltinOperator_GATHER) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc
    return {0};
  }
  return {};
}

// Checks that a specific input can be quantized.
bool IsQuantizedInput(const OperatorCodeT* op_code,
                      const CustomOpMap& custom_op_map, int op_input_idx) {
  const auto quantized_input_indices =
      GetWeightInputIndices(op_code, custom_op_map);
  return std::find(std::begin(quantized_input_indices),
                   std::end(quantized_input_indices),
                   op_input_idx) != std::end(quantized_input_indices);
}

// Returns true if the operator supports hybrid evaluation.
bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
                          const CustomOpMap& custom_op_map,
                          bool use_updated_hybrid_scheme) {
  const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code);
  // Operations that support hybrid evaluation.
  bool eval_hybrid = false;
  if (builtin_op_code == BuiltinOperator_CUSTOM) {
    const std::string custom_code = op_code->custom_code;
    const auto custom_op_info = custom_op_map.find(custom_code);
    if (custom_op_info == custom_op_map.end()) {
      return false;
    } else {
      return custom_op_info->second.is_hybrid;
    }
  } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
             builtin_op_code == BuiltinOperator_CONV_2D ||
             builtin_op_code == BuiltinOperator_SVDF ||
             builtin_op_code == BuiltinOperator_RNN ||
             builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
             builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    eval_hybrid = true;
  } else if (builtin_op_code == BuiltinOperator_LSTM) {
    const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions();
    // Only the LSTM FULL kernel type supports hybrid evaluation.
    if (options->kernel_type == LSTMKernelType_FULL) {
      eval_hybrid = true;
    }
  } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
    eval_hybrid = use_updated_hybrid_scheme;
  }
  return eval_hybrid;
}

// Returns true if all of the op's inputs are quantized.
bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op,
                               const OperatorCodeT* op_code,
                               const CustomOpMap& custom_op_map) {
  std::vector<int32_t> op_input_indices =
      GetWeightInputIndices(op_code, custom_op_map);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];

    if (tensor_idx == -1) {
      // Optional tensor.
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();

    if (tensor->type != TensorType_INT8) {
      return false;
    }
  }
  return true;
}

// Inserts into tensor_map an entry for each input tensor of op that should be
// quantized.
TfLiteStatus InsertQuantizableInputTensorsFromOperator(
    const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements,
    const CustomOpMap& custom_op_map,
    absl::flat_hash_map<int32_t, TensorPerChannel>* tensor_map,
    int subgraph_index, bool use_updated_hybrid_scheme) {
  SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();
  const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
  auto builtin_code = GetBuiltinCode(op_code);

  std::vector<int32_t> op_input_indices =
      GetWeightInputIndices(op_code, custom_op_map);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];
    if (tensor_idx == -1) {
      LOG(INFO) << "Skipping optional tensor input " << op_input_idx
                << " of operation " << EnumNameBuiltinOperator(builtin_code);
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();
    if (tensor->type != TensorType_FLOAT32) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " that is not type float.";
      continue;
    }

    uint64_t num_elements;
    TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
    if (num_elements < weights_min_num_elements) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has fewer than " << weights_min_num_elements
                << " elements (" << num_elements << ").";
      continue;
    }

    // Some tensors may have a null buffer vector, indicating an intermediate
    // array.
    if (model->buffers[tensor->buffer]->data.data() == nullptr) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has no allocated buffer.";
      continue;
    }

    if (builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
      tensor_map->insert({tensor_idx,
                          {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
                           /*dim=*/3}});
    } else if (builtin_code == BuiltinOperator_CONV_2D) {
      tensor_map->insert({tensor_idx,
                          {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
                           /*dim=*/0}});
    } else {
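      // For ops that will run as hybrid kernels, the updated hybrid scheme
      // additionally requests runtime asymmetric quantization of the float
      // activations via the asymmetric_quantize_inputs option set below.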
      switch (builtin_code) {
        case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
          op->builtin_options.AsBidirectionalSequenceLSTMOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
          op->builtin_options.AsBidirectionalSequenceRNNOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_FULLY_CONNECTED:
          op->builtin_options.AsFullyConnectedOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_BATCH_MATMUL:
          op->builtin_options.AsBatchMatMulOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_LSTM:
          op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
              use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_RNN:
          op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs =
              use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_SVDF:
          op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs =
              use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
          op->builtin_options.AsUnidirectionalSequenceLSTMOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
          op->builtin_options.AsSequenceRNNOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        default:
          break;
      }
      tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/false}});
    }
  }

  return kTfLiteOk;
}

// Updates operator code versions for the operators with INT8 inputs.
void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) {
  for (int i = 0, end = model->operator_codes.size(); i < end; ++i) {
    const BuiltinOperator& op_code =
        GetBuiltinCode(model->operator_codes[i].get());
    if (op_code == BuiltinOperator_RNN ||
        op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2;
    } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
               op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
      model->operator_codes[i]->version = 3;
    } else if (op_code == BuiltinOperator_LSTM) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3;
    } else if (op_code == BuiltinOperator_CONV_2D) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2;
    } else if (op_code == BuiltinOperator_FULLY_CONNECTED) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3;
    } else if (op_code == BuiltinOperator_BATCH_MATMUL) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 1;
    } else if (op_code == BuiltinOperator_SVDF) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2;
    } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
      model->operator_codes[i]->version = 6;
    }
  }
}

// Returns true if the op in consumer_op_infos can pass through quantization.
bool IsQuantizationPassThroughOps(
    const ModelT* model, const std::vector<ConsumerOpInfo>& consumer_op_infos) {
  if (consumer_op_infos.size() != 1) {
    return false;
  }
  const OperatorT* consumer_op = consumer_op_infos.front().op;
  const BuiltinOperator op_code =
      GetBuiltinCode(model->operator_codes[consumer_op->opcode_index].get());
  return op_code == BuiltinOperator_GATHER ||
         op_code == BuiltinOperator_EMBEDDING_LOOKUP;
}

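// A note on pass-through ops: a GATHER or EMBEDDING_LOOKUP consumer only
// reads rows of the quantized weight tensor and forwards the int8 values
// unchanged, so instead of dequantizing in front of it, the quantization
// parameters are copied to its output and the hybrid-vs-dequantize decision
// is deferred to that output's consumers.
//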
// Copies quantization parameters from input to output and returns consumers of
// the output tensor as a tuple with values:
// - index of the output tensor
// - pointer to the output tensor
// - vector of consumer ops.
std::tuple<int32_t, TensorT*, std::vector<ConsumerOpInfo>>
PassQuantizationAndGetConsumers(
    const ModelT* model, const SubGraphT* subgraph,
    const std::vector<ConsumerOpInfo>& consumer_op_infos,
    const CustomOpMap& custom_op_map) {
  const OperatorT* op = consumer_op_infos.front().op;
  const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
  if (op->outputs.size() != 1) {
    LOG(ERROR)
        << "An op that passes quantization has more than one quantized output";
    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
  }
  const int32_t output_tensor_idx = op->outputs.front();
  const auto input_idx = GetWeightInputIndices(op_code, custom_op_map);
  if (input_idx.size() != 1) {
    LOG(ERROR)
        << "An op that passes quantization has more than one quantized input";
    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
  }
  const int32_t input_tensor_idx = op->inputs[input_idx.front()];

  // Propagate quantization params.
  const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get();
  TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get();
  if (!output_tensor->quantization) {
    output_tensor->quantization = std::make_unique<QuantizationParametersT>();
  }
  *output_tensor->quantization = *input_tensor->quantization;
  output_tensor->type = TensorType_INT8;
  return std::make_tuple(
      output_tensor_idx, output_tensor,
      GetTensorConsumers(model, subgraph, output_tensor_idx));
}

inline bool IsOpDenylisted(const flat_hash_set<BuiltinOperator>& op_denylist,
                           const BuiltinOperator op_code) {
  return op_denylist.find(op_code) != op_denylist.end();
}

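// Symmetrically quantizes eligible float32 weight tensors in the model to
// int8. Weights whose consumers can run as hybrid kernels are left int8 for
// those kernels; for every other consumer (and for quantized tensors that are
// also subgraph outputs) a Dequantize op back to float32 is inserted.
// Operator versions are then bumped to the variants that accept int8 weights.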
TfLiteStatus QuantizeWeightsInt8(
    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
    bool use_hybrid_evaluation, uint64_t weights_min_num_elements,
    const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme,
    const flat_hash_set<BuiltinOperator>& op_denylist = {}) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());

  for (int subgraph_index = 0, end = model->subgraphs.size();
       subgraph_index < end; ++subgraph_index) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();

    absl::flat_hash_map<int32_t, TensorPerChannel> tensor_map;
    for (int i = 0; i < subgraph->operators.size(); ++i) {
      OperatorT* op = subgraph->operators[i].get();
      TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
          model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map,
          subgraph_index, use_updated_hybrid_scheme));
    }

    for (std::pair<int32_t, TensorPerChannel> tensor_pair : tensor_map) {
      // Quantize the tensor.
      if (tensor_pair.second.is_per_channel) {
        TF_LITE_ENSURE_STATUS(utils::SymmetricQuantizeTensorPerChannel(
            model.get(), tensor_pair.second.t, tensor_pair.second.channel_dim,
            nullptr));
      } else {
        TF_LITE_ENSURE_STATUS(
            utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t));
      }
    }

    // Examine the tensor consumers to determine which require dequantize ops.
    for (const auto& tensor_pair : tensor_map) {
      int32_t tensor_idx = tensor_pair.first;
      TensorT* tensor = tensor_pair.second.t;
      std::vector<ConsumerOpInfo> consumer_op_infos =
          GetTensorConsumers(model.get(), subgraph, tensor_idx);
      if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) {
        std::tie(tensor_idx, tensor, consumer_op_infos) =
            PassQuantizationAndGetConsumers(model.get(), subgraph,
                                            consumer_op_infos, custom_op_map);
        if (tensor_idx < 0) {
          // Error message is already logged by PassQuantizationAndGetConsumers.
          return kTfLiteError;
        }
      }

      std::vector<ConsumerOpInfo> dequant_op_infos;  // Ops that need dequants.
      for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
        OperatorT* consumer_op = consumer_op_info.op;
        const OperatorCodeT* consumer_op_code =
            model->operator_codes[consumer_op->opcode_index].get();
        // If the op is a hybrid op and all the required tensors are quantized,
        // we have no further work to do, but for all ops that require
        // dequantization we need to add a Dequantize op.
        bool eval_hybrid =
            use_hybrid_evaluation &&
            !IsOpDenylisted(op_denylist, GetBuiltinCode(consumer_op_code)) &&
            IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map,
                                 use_updated_hybrid_scheme) &&
            CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code,
                                      custom_op_map) &&
            IsQuantizedInput(consumer_op_code, custom_op_map,
                             consumer_op_info.op_input_idx);
        if (!eval_hybrid) {
          dequant_op_infos.push_back(consumer_op_info);
        }
      }

      // Check if this tensor is an output tensor.
      int32_t output_index = -1;
      for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
        if (subgraph->outputs[i] == tensor_idx) {
          output_index = i;
          break;
        }
      }

      // If no ops require a dequant and the tensor is not a subgraph output,
      // we are done with this tensor.
      if (dequant_op_infos.empty() && output_index < 0) {
        continue;
      }

      // Create a new tensor to be the output of the dequantize op.
      std::unique_ptr<TensorT> dequantize_output;
      const string dequant_name = tensor->name + "_dequantize";
      utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature,
                        TensorType_FLOAT32, &dequantize_output);
      const int32_t dequantize_output_idx = subgraph->tensors.size();
      subgraph->tensors.push_back(std::move(dequantize_output));

      // Create the Dequantize operation.
      std::unique_ptr<OperatorT> dequantize_op;
      utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
                                    dequantize_output_idx);

      // Update the op_input of all the ops that need the created dequantize
      // operation.
      int32_t min_op_idx = subgraph->operators.size();
      for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
        dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
            dequantize_output_idx;
        min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
      }
      // Point the subgraph output at the dequantized tensor.
      if (output_index >= 0) {
        subgraph->outputs[output_index] = dequantize_output_idx;
      }

      // Insert the newly created Dequantize operation before the earliest
      // consumer, since TFLite requires operators to be topo-sorted.
      subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
                                 std::move(dequantize_op));
    }
  }

  // Update the modified operator code versions.
  UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme);

  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

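// Quantizes every float32 constant tensor that feeds an op to float16 and
// inserts a Dequantize op back to float32 in front of its consumers. Unlike
// the int8 path there is no hybrid evaluation: all ops keep running on the
// dequantized float32 values.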
TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder,
                                    const Model* input_model) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());

  for (int subgraph_index = 0, end = model->subgraphs.size();
       subgraph_index < end; ++subgraph_index) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();

    absl::flat_hash_map<int32_t, TensorT*> tensor_map;
    for (int i = 0, sub_end = subgraph->operators.size(); i < sub_end; ++i) {
      OperatorT* op = subgraph->operators[i].get();
      for (auto tensor_idx : op->inputs) {
        // Skip optional tensors.
        if (tensor_idx == kTfLiteOptionalTensor) {
          continue;
        }
        TensorT* tensor = subgraph->tensors[tensor_idx].get();
        BufferT* buffer = model->buffers[tensor->buffer].get();
        if (buffer == nullptr) {
          return kTfLiteError;
        }
        // Quantize tensors that have data to quantize.
        bool is_constant = !buffer->data.empty();
        if (tensor->type == TensorType_FLOAT32 && is_constant) {
          tensor_map.insert({tensor_idx, tensor});
        }
      }
    }

    // The hash map ensures that we quantize each tensor exactly once.
    for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
      // Quantize the tensor.
      TF_LITE_ENSURE_STATUS(
          utils::QuantizeTensorFloat16(model.get(), tensor_pair.second));

      int32_t tensor_idx = tensor_pair.first;
      TensorT* tensor = tensor_pair.second;
      std::vector<ConsumerOpInfo> dequant_op_infos =
          GetTensorConsumers(model.get(), subgraph, tensor_idx);

      // Create a new tensor to be the output of the dequantize op.
      std::unique_ptr<TensorT> dequantize_output;
      const string dequant_name = tensor->name + "_dequantize";
      utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature,
                        TensorType_FLOAT32, &dequantize_output);
      const int32_t dequantize_output_idx = subgraph->tensors.size();
      subgraph->tensors.push_back(std::move(dequantize_output));

      // Create the Dequantize operation.
      std::unique_ptr<OperatorT> dequantize_op;
      utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
                                    dequantize_output_idx);

      // Update the op_input of all the ops that need the created dequantize
      // operation.
      int32_t min_op_idx = subgraph->operators.size();
      for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
        dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
            dequantize_output_idx;
        min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
      }

      // Insert the newly created Dequantize operation before the earliest
      // consumer, since TFLite requires operators to be topo-sorted.
      subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
                                 std::move(dequantize_op));
    }
  }

  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);
  return kTfLiteOk;
}
}  // namespace

namespace internal {
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             bool use_hybrid_evaluation,
                             QuantizerType quantizer_type) {
  // Only weights with at least weights_min_num_elements elements are
  // quantized.
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    return mlir::lite::QuantizeWeights(
        builder, input_model, weights_min_num_elements, use_hybrid_evaluation);
  }
  CustomOpMap custom_op_map;
  return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation,
                             weights_min_num_elements, custom_op_map,
                             kUseUpdatedHybridSchemeDefault);
}
}  // namespace internal

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             QuantizerType quantizer_type) {
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    return mlir::lite::QuantizeWeights(builder, input_model,
                                       weights_min_num_elements);
  }
  CustomOpMap custom_op_map;
  return QuantizeWeightsInt8(builder, input_model, true,
                             weights_min_num_elements, custom_op_map,
                             kUseUpdatedHybridSchemeDefault);
}
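// Example call sequence (a minimal sketch; loading `model_buffer` from disk
// and writing out the result are left to the caller, and the trailing
// QuantizerType argument is assumed to be defaulted in quantize_weights.h):
//
//   const Model* input_model = GetModel(model_buffer.data());
//   flatbuffers::FlatBufferBuilder builder;
//   if (QuantizeWeights(&builder, input_model,
//                       /*weights_min_num_elements=*/1024) == kTfLiteOk) {
//     // builder.GetBufferPointer() / builder.GetSize() now hold the
//     // weight-quantized model.
//   }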

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model, BufferType quant_type,
                             bool use_updated_hybrid_scheme,
                             QuantizerType quantizer_type) {
  // By default, only weights with at least kWeightsMinNumElementsDefault
  // elements are quantized.
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    return mlir::lite::QuantizeWeights(builder, input_model,
                                       (mlir::lite::BufferType)quant_type,
                                       use_updated_hybrid_scheme);
  }
  switch (quant_type) {
    case BufferType::QUANTIZED_INT8: {
      CustomOpMap custom_op_map;
      return QuantizeWeightsInt8(builder, input_model, true,
                                 kWeightsMinNumElementsDefault, custom_op_map,
                                 use_updated_hybrid_scheme);
    }
    case BufferType::QUANTIZED_FLOAT16:
      return QuantizeWeightsFloat16(builder, input_model);
  }
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             const CustomOpMap& custom_op_map,
                             QuantizerType quantizer_type) {
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    mlir::lite::CustomOpMap mlir_custom_op_map;
    ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map);
    return mlir::lite::QuantizeWeights(
        builder, input_model, weights_min_num_elements, mlir_custom_op_map);
  }
  return QuantizeWeightsInt8(builder, input_model, true,
                             weights_min_num_elements, custom_op_map,
                             kUseUpdatedHybridSchemeDefault);
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             const CustomOpMap& custom_op_map,
                             bool use_updated_hybrid_scheme,
                             const flat_hash_set<BuiltinOperator>& op_denylist,
                             QuantizerType quantizer_type) {
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    mlir::lite::CustomOpMap mlir_custom_op_map;
    ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map);
    return mlir::lite::QuantizeWeights(
        builder, input_model, weights_min_num_elements, mlir_custom_op_map,
        use_updated_hybrid_scheme, op_denylist);
  }
  return QuantizeWeightsInt8(builder, input_model,
                             /*use_hybrid_evaluation=*/true,
                             weights_min_num_elements, custom_op_map,
                             use_updated_hybrid_scheme, op_denylist);
}

}  // namespace optimize
}  // namespace tflite