/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/model_utils.h"
#include "tensorflow/lite/tools/optimize/quantization_utils.h"

namespace tflite {
namespace optimize {

namespace {

struct ConsumerOpInfo {
  OperatorT* op;
  // The index of the op in the operators vector.
  int32_t op_idx;
  // The index into op->inputs of the tensor to quantize.
  int32_t op_input_idx;
};

struct TensorPerChannel {
  TensorT* t;
  bool is_per_channel;
  int channel_dim;
};

// The default minimum number of elements a weights array must have to be
// quantized by this transformation.
const int kWeightsMinNumElementsDefault = 1024;

// Constructs the MLIR CustomOpMap from the TFLite CustomOpMap, since their
// member variables differ.
void ConstructMLIRCustomOpMap(mlir::lite::CustomOpMap& mlir_map,
                              const CustomOpMap& tflite_map) {
  for (const auto& entry : tflite_map) {
    mlir_map[entry.first].quantizable_input_indices =
        entry.second.quantizable_input_indices;
    mlir_map[entry.first].is_weight_only = !entry.second.is_hybrid;
    mlir_map[entry.first].no_side_effect = true;
  }
}

// Gets the operators that consume tensor_idx.
std::vector<ConsumerOpInfo> GetTensorConsumers(const ModelT* model,
                                               const SubGraphT* subgraph,
                                               int32_t tensor_idx) {
  // TODO(suharshs): If this proves to be too slow, avoid calling it per tensor,
  // instead doing one sweep for the entire model.
  std::vector<ConsumerOpInfo> consumer_ops;
  for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
    OperatorT* op = subgraph->operators[op_idx].get();
    if (op == nullptr) {
      continue;
    }
    for (size_t i = 0; i < op->inputs.size(); ++i) {
      if (op->inputs[i] == tensor_idx) {
        consumer_ops.push_back(
            {op, static_cast<int32_t>(op_idx), static_cast<int32_t>(i)});
      }
    }
  }
  return consumer_ops;
}

// Gets the list of op->inputs indices of the weights inputs to be quantized for
// the provided op.
std::vector<int32_t> GetWeightInputIndices(const OperatorCodeT* op_code,
                                           const CustomOpMap& custom_op_map) {
  const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code);
  if (builtin_op_code == BuiltinOperator_CUSTOM) {
    const std::string custom_code = op_code->custom_code;
    const auto& custom_op_info = custom_op_map.find(custom_code);
    if (custom_op_info != custom_op_map.end()) {
      return custom_op_info->second.quantizable_input_indices;
    }
  } else if (builtin_op_code == BuiltinOperator_CONV_2D ||
             builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
             builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
             builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
             builtin_op_code == BuiltinOperator_TRANSPOSE_CONV) {
    return {1};
  } else if (builtin_op_code == BuiltinOperator_SVDF) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/svdf.cc
    return {1, 2};
  } else if (builtin_op_code == BuiltinOperator_LSTM ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/lstm.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc
    return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16};
  } else if (builtin_op_code == BuiltinOperator_RNN ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/basic_rnn.cc
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc
    return {1, 2};
  } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc
    return {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 16, 18, 19, 20, 21,
            22, 23, 24, 25, 26, 27, 28, 33, 40, 41, 42, 43, 44, 45, 46, 47};
  } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc
    return {1, 2, 4, 5, 6, 8, 9, 10, 11};
  } else if (builtin_op_code == BuiltinOperator_GATHER) {
    // https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc
    return {0};
  }
  return {};
}

// Returns true if the given op input index is one of the op's quantizable
// weight inputs.
bool IsQuantizedInput(const OperatorCodeT* op_code,
                      const CustomOpMap& custom_op_map, int op_input_idx) {
  const auto quantized_input_indices =
      GetWeightInputIndices(op_code, custom_op_map);
  return std::find(std::begin(quantized_input_indices),
                   std::end(quantized_input_indices),
                   op_input_idx) != std::end(quantized_input_indices);
}

// Returns true if the operator supports hybrid evaluation.
bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
                          const CustomOpMap& custom_op_map,
                          bool use_updated_hybrid_scheme) {
  const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code);
  // Operations that support hybrid evaluation.
  bool eval_hybrid = false;
  if (builtin_op_code == BuiltinOperator_CUSTOM) {
    const std::string custom_code = op_code->custom_code;
    const auto custom_op_info = custom_op_map.find(custom_code);
    if (custom_op_info == custom_op_map.end()) {
      return false;
    } else {
      return custom_op_info->second.is_hybrid;
    }
  } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
             builtin_op_code == BuiltinOperator_BATCH_MATMUL ||
             builtin_op_code == BuiltinOperator_CONV_2D ||
             builtin_op_code == BuiltinOperator_SVDF ||
             builtin_op_code == BuiltinOperator_RNN ||
             builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
             builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
             builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
    eval_hybrid = true;
  } else if (builtin_op_code == BuiltinOperator_LSTM) {
    const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions();
    // Only the LSTM kernel_type FULL supports hybrid evaluation.
    if (options->kernel_type == LSTMKernelType_FULL) {
      eval_hybrid = true;
    }
  } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
    eval_hybrid = use_updated_hybrid_scheme;
  }
  return eval_hybrid;
}

// Returns true if all of the op's quantizable weight inputs are quantized.
bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op,
                               const OperatorCodeT* op_code,
                               const CustomOpMap& custom_op_map) {
  std::vector<int32_t> op_input_indices =
      GetWeightInputIndices(op_code, custom_op_map);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];

    if (tensor_idx == -1) {
      // Optional tensor.
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();

    if (tensor->type != TensorType_INT8) {
      return false;
    }
  }
  return true;
}
// Inserts an entry into tensor_map for each input tensor of op that should be
// quantized.
TfLiteStatus InsertQuantizableInputTensorsFromOperator(
    const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements,
    const CustomOpMap& custom_op_map,
    absl::flat_hash_map<int32_t, TensorPerChannel>* tensor_map,
    int subgraph_index, bool use_updated_hybrid_scheme) {
  SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();
  const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
  auto builtin_code = GetBuiltinCode(op_code);

  std::vector<int32_t> op_input_indices =
      GetWeightInputIndices(op_code, custom_op_map);
  for (const int32_t op_input_idx : op_input_indices) {
    int32_t tensor_idx = op->inputs[op_input_idx];
    if (tensor_idx == -1) {
      LOG(INFO) << "Skipping optional tensor input " << op_input_idx
                << " of operation " << EnumNameBuiltinOperator(builtin_code);
      continue;
    }

    TensorT* tensor = subgraph->tensors[tensor_idx].get();
    if (tensor->type != TensorType_FLOAT32) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " that is not type float.";
      continue;
    }

    uint64_t num_elements;
    TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements));
    if (num_elements < weights_min_num_elements) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has fewer than " << weights_min_num_elements
                << " elements (" << num_elements << ").";
      continue;
    }

    // Some tensors may have a null buffer vector, indicating an intermediate
    // array.
    if (model->buffers[tensor->buffer]->data.data() == nullptr) {
      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
                << " because it has no allocated buffer.";
      continue;
    }

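    // Conv2D and DepthwiseConv2D weights can be quantized per channel under
    // the updated hybrid scheme; the quantized dimension is the output-channel
    // axis of each filter layout (dim 0 for Conv2D, dim 3 for DepthwiseConv2D).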
    if (builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
      tensor_map->insert({tensor_idx,
                          {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
                           /*dim=*/3}});
    } else if (builtin_code == BuiltinOperator_CONV_2D) {
      tensor_map->insert({tensor_idx,
                          {tensor, /*is_per_channel=*/use_updated_hybrid_scheme,
                           /*dim=*/0}});
    } else {
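      // All other ops keep per-tensor weight quantization; record in their
      // builtin options whether the hybrid kernel should asymmetrically
      // quantize its float inputs at runtime (the updated hybrid scheme).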
      switch (builtin_code) {
        case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
          op->builtin_options.AsBidirectionalSequenceLSTMOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
          op->builtin_options.AsBidirectionalSequenceRNNOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_FULLY_CONNECTED:
          op->builtin_options.AsFullyConnectedOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_BATCH_MATMUL:
          op->builtin_options.AsBatchMatMulOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_LSTM:
          op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs =
              use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_RNN:
          op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs =
              use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_SVDF:
          op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs =
              use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
          op->builtin_options.AsUnidirectionalSequenceLSTMOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
          op->builtin_options.AsSequenceRNNOptions()
              ->asymmetric_quantize_inputs = use_updated_hybrid_scheme;
          break;
        default:
          break;
      }
      tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/false}});
    }
  }

  return kTfLiteOk;
}

// Updates operator code versions for the operators with INT8 inputs.
void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) {
  for (int i = 0, end = model->operator_codes.size(); i < end; ++i) {
    const BuiltinOperator& op_code =
        GetBuiltinCode(model->operator_codes[i].get());
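    // Bump each hybrid-capable op to a kernel version that handles int8
    // weights; when the updated hybrid scheme is used, a newer version that
    // supports asymmetric input quantization is selected where needed.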
    if (op_code == BuiltinOperator_RNN ||
        op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
        op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2;
    } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
               op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
      model->operator_codes[i]->version = 3;
    } else if (op_code == BuiltinOperator_LSTM) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3;
    } else if (op_code == BuiltinOperator_CONV_2D) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2;
    } else if (op_code == BuiltinOperator_FULLY_CONNECTED) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3;
    } else if (op_code == BuiltinOperator_BATCH_MATMUL) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 1;
    } else if (op_code == BuiltinOperator_SVDF) {
      model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2;
    } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) {
      model->operator_codes[i]->version = 6;
    }
  }
}

// Returns true if the op in consumer_op_infos can pass through quantization.
bool IsQuantizationPassThroughOps(
    const ModelT* model, const std::vector<ConsumerOpInfo>& consumer_op_infos) {
  if (consumer_op_infos.size() != 1) {
    return false;
  }
  const OperatorT* consumer_op = consumer_op_infos.front().op;
  const BuiltinOperator op_code =
      GetBuiltinCode(model->operator_codes[consumer_op->opcode_index].get());
  return op_code == BuiltinOperator_GATHER ||
         op_code == BuiltinOperator_EMBEDDING_LOOKUP;
}

// Copies quantization parameters from input to output and returns consumers of
// the output tensor as a tuple with values:
// - index of the output tensor
// - pointer to the output tensor
// - vector of consumer ops.
std::tuple<int32_t, TensorT*, std::vector<ConsumerOpInfo>>
PassQuantizationAndGetConsumers(
    const ModelT* model, const SubGraphT* subgraph,
    const std::vector<ConsumerOpInfo>& consumer_op_infos,
    const CustomOpMap& custom_op_map) {
  const OperatorT* op = consumer_op_infos.front().op;
  const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
  if (op->outputs.size() != 1) {
    LOG(ERROR)
        << "An op that passes quantization has more than one quantized output";
    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
  }
  const int32_t output_tensor_idx = op->outputs.front();
  const auto input_idx = GetWeightInputIndices(op_code, custom_op_map);
  if (input_idx.size() != 1) {
    LOG(ERROR)
        << "An op that passes quantization has more than one quantized input";
    return std::make_tuple(-1, nullptr, std::vector<ConsumerOpInfo>());
  }
  const int32_t input_tensor_idx = op->inputs[input_idx.front()];

  // Propagate quantization params.
  const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get();
  TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get();
  if (!output_tensor->quantization) {
    output_tensor->quantization = std::make_unique<QuantizationParametersT>();
  }
  *output_tensor->quantization = *input_tensor->quantization;
  output_tensor->type = TensorType_INT8;
  return std::make_tuple(
      output_tensor_idx, output_tensor,
      GetTensorConsumers(model, subgraph, output_tensor_idx));
}

inline bool IsOpDenylisted(const flat_hash_set<BuiltinOperator>& op_denylist,
                           const BuiltinOperator op_code) {
  return op_denylist.find(op_code) != op_denylist.end();
}

TfLiteStatus QuantizeWeightsInt8(
    flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
    bool use_hybrid_evaluation, uint64_t weights_min_num_elements,
    const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme,
    const flat_hash_set<BuiltinOperator>& op_denylist = {}) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());

  for (int subgraph_index = 0, end = model->subgraphs.size();
       subgraph_index < end; ++subgraph_index) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();

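    // First pass: collect every weight tensor in this subgraph that qualifies
    // for int8 quantization, keyed by tensor index so each tensor is
    // quantized at most once.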
    absl::flat_hash_map<int32_t, TensorPerChannel> tensor_map;
    for (int i = 0; i < subgraph->operators.size(); ++i) {
      OperatorT* op = subgraph->operators[i].get();
      TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
          model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map,
          subgraph_index, use_updated_hybrid_scheme));
    }

    for (std::pair<int32_t, TensorPerChannel> tensor_pair : tensor_map) {
      // Quantize the tensor.
      if (tensor_pair.second.is_per_channel) {
        TF_LITE_ENSURE_STATUS(utils::SymmetricQuantizeTensorPerChannel(
            model.get(), tensor_pair.second.t, tensor_pair.second.channel_dim,
            nullptr));
      } else {
        TF_LITE_ENSURE_STATUS(
            utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t));
      }
    }

    // Examine the tensor consumers to determine which require dequantize ops.
    for (const auto& tensor_pair : tensor_map) {
      int32_t tensor_idx = tensor_pair.first;
      TensorT* tensor = tensor_pair.second.t;
      std::vector<ConsumerOpInfo> consumer_op_infos =
          GetTensorConsumers(model.get(), subgraph, tensor_idx);
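      // Gather and EmbeddingLookup simply forward the weight, so push the
      // quantization parameters through to their output and treat that
      // output's consumers as the consumers to rewire below.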
      if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) {
        std::tie(tensor_idx, tensor, consumer_op_infos) =
            PassQuantizationAndGetConsumers(model.get(), subgraph,
                                            consumer_op_infos, custom_op_map);
        if (tensor_idx < 0) {
          // Error message is already logged by PassQuantizationAndGetConsumers.
          return kTfLiteError;
        }
      }

      std::vector<ConsumerOpInfo> dequant_op_infos;  // Ops that need dequants.
      for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) {
        OperatorT* consumer_op = consumer_op_info.op;
        const OperatorCodeT* consumer_op_code =
            model->operator_codes[consumer_op->opcode_index].get();
        // If the op is a hybrid op and all the required tensors are quantized,
        // we have no further work to do, but for all ops that require
        // dequantization we need to add a Dequantize op.
        bool eval_hybrid =
            use_hybrid_evaluation &&
            !IsOpDenylisted(op_denylist, GetBuiltinCode(consumer_op_code)) &&
            IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map,
                                 use_updated_hybrid_scheme) &&
            CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code,
                                      custom_op_map) &&
            IsQuantizedInput(consumer_op_code, custom_op_map,
                             consumer_op_info.op_input_idx);
        if (!eval_hybrid) {
          dequant_op_infos.push_back(consumer_op_info);
        }
      }

      // Check if this tensor is an output tensor.
      int32_t output_index = -1;
      for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
        if (subgraph->outputs[i] == tensor_idx) {
          output_index = i;
          break;
        }
      }

      // If no ops require dequant and it is not output, we are done for this
      // tensor.
      if (dequant_op_infos.empty() && output_index < 0) {
        continue;
      }

      // Create a new tensor to be the output of the dequantize op.
      std::unique_ptr<TensorT> dequantize_output;
      const string dequant_name = tensor->name + "_dequantize";
      utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature,
                        TensorType_FLOAT32, &dequantize_output);
      const int32_t dequantize_output_idx = subgraph->tensors.size();
      subgraph->tensors.push_back(std::move(dequantize_output));

      // Create the Dequantize operation.
      std::unique_ptr<OperatorT> dequantize_op;
      utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
                                    dequantize_output_idx);

      // Update the op_input of all the ops that need the created dequantize
      // operation.
      int32_t min_op_idx = subgraph->operators.size();
      for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
        dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
            dequantize_output_idx;
        min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
      }
      // If the tensor is a subgraph output, point the output at the
      // dequantized tensor instead.
      if (output_index >= 0) {
        subgraph->outputs[output_index] = dequantize_output_idx;
      }

      // Insert the newly created Dequantize operation before the earliest
      // consumer, since TFLite requires operators to be topo-sorted.
      subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
                                 std::move(dequantize_op));
    }
  }

  // Update the modified operator code versions.
  UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme);

  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);

  return kTfLiteOk;
}

TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder,
                                    const Model* input_model) {
  std::unique_ptr<ModelT> model;
  model.reset(input_model->UnPack());

  for (int subgraph_index = 0, end = model->subgraphs.size();
       subgraph_index < end; ++subgraph_index) {
    SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get();

    absl::flat_hash_map<int32_t, TensorT*> tensor_map;
    for (int i = 0, sub_end = subgraph->operators.size(); i < sub_end; ++i) {
      OperatorT* op = subgraph->operators[i].get();
      for (auto tensor_idx : op->inputs) {
        // Skip optional tensors.
        if (tensor_idx == kTfLiteOptionalTensor) {
          continue;
        }
        TensorT* tensor = subgraph->tensors[tensor_idx].get();
        BufferT* buffer = model->buffers[tensor->buffer].get();
        if (buffer == nullptr) {
          return kTfLiteError;
        }
        // Quantize tensors that have data to quantize.
        bool is_constant = !model->buffers[tensor->buffer].get()->data.empty();
        if (tensor->type == TensorType_FLOAT32 && is_constant) {
          tensor_map.insert({tensor_idx, tensor});
        }
      }
    }

    // The hash map ensures that we quantize each tensor exactly once.
    for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
      // Quantize the tensor.
      TF_LITE_ENSURE_STATUS(
          utils::QuantizeTensorFloat16(model.get(), tensor_pair.second));

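      // Every consumer of the fp16 weight will read it through a Dequantize
      // op back to float32, so collect all consumers of the original tensor.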
      int32_t tensor_idx = tensor_pair.first;
      TensorT* tensor = tensor_pair.second;
      std::vector<ConsumerOpInfo> dequant_op_infos =
          GetTensorConsumers(model.get(), subgraph, tensor_idx);

      // Create a new tensor to be the output of the dequantize op.
      std::unique_ptr<TensorT> dequantize_output;
      const string dequant_name = tensor->name + "_dequantize";
      utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature,
                        TensorType_FLOAT32, &dequantize_output);
      const int32_t dequantize_output_idx = subgraph->tensors.size();
      subgraph->tensors.push_back(std::move(dequantize_output));

      // Create the Dequantize operation.
      std::unique_ptr<OperatorT> dequantize_op;
      utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
                                    dequantize_output_idx);

      // Update the op_input of all the ops that need the created dequantize
      // operation.
      int32_t min_op_idx = subgraph->operators.size();
      for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
        dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
            dequantize_output_idx;
        min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
      }

      // Insert the newly created Dequantize operation before the earliest
      // consumer, since TFLite requires operators to be topo-sorted.
      subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
                                 std::move(dequantize_op));
    }
  }

  flatbuffers::Offset<Model> output_model_location =
      Model::Pack(*builder, model.get());
  FinishModelBuffer(*builder, output_model_location);
  return kTfLiteOk;
}
}  // namespace

namespace internal {
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             bool use_hybrid_evaluation,
                             QuantizerType quantizer_type) {
  // Only weights with at least weights_min_num_elements elements are
  // quantized.
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    return mlir::lite::QuantizeWeights(
        builder, input_model, weights_min_num_elements, use_hybrid_evaluation);
  }
  CustomOpMap custom_op_map;
  return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation,
                             weights_min_num_elements, custom_op_map,
                             kUseUpdatedHybridSchemeDefault);
}
}  // namespace internal

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             QuantizerType quantizer_type) {
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    return mlir::lite::QuantizeWeights(builder, input_model,
                                       weights_min_num_elements);
  }
  CustomOpMap custom_op_map;
  return QuantizeWeightsInt8(builder, input_model, true,
                             weights_min_num_elements, custom_op_map,
                             kUseUpdatedHybridSchemeDefault);
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model, BufferType quant_type,
                             bool use_updated_hybrid_scheme,
                             QuantizerType quantizer_type) {
  // By default, only weights with at least kWeightsMinNumElementsDefault
  // elements are quantized.
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    return mlir::lite::QuantizeWeights(builder, input_model,
                                       (mlir::lite::BufferType)quant_type,
                                       use_updated_hybrid_scheme);
  }
  switch (quant_type) {
    case BufferType::QUANTIZED_INT8: {
      CustomOpMap custom_op_map;
      return QuantizeWeightsInt8(builder, input_model, true,
                                 kWeightsMinNumElementsDefault, custom_op_map,
                                 use_updated_hybrid_scheme);
    }
    case BufferType::QUANTIZED_FLOAT16:
      return QuantizeWeightsFloat16(builder, input_model);
  }
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             const CustomOpMap& custom_op_map,
                             QuantizerType quantizer_type) {
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    mlir::lite::CustomOpMap mlir_custom_op_map;
    ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map);
    return mlir::lite::QuantizeWeights(
        builder, input_model, weights_min_num_elements, mlir_custom_op_map);
  }
  return QuantizeWeightsInt8(builder, input_model, true,
                             weights_min_num_elements, custom_op_map,
                             kUseUpdatedHybridSchemeDefault);
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const Model* input_model,
                             uint64_t weights_min_num_elements,
                             const CustomOpMap& custom_op_map,
                             bool use_updated_hybrid_scheme,
                             const flat_hash_set<BuiltinOperator>& op_denylist,
                             QuantizerType quantizer_type) {
  if (quantizer_type == QuantizerType::MLIR_QUANTIZER) {
    mlir::lite::CustomOpMap mlir_custom_op_map;
    ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map);
    return mlir::lite::QuantizeWeights(
        builder, input_model, weights_min_num_elements, mlir_custom_op_map,
        use_updated_hybrid_scheme, op_denylist);
  }
  return QuantizeWeightsInt8(builder, input_model,
                             /*use_hybrid_evaluation=*/true,
                             weights_min_num_elements, custom_op_map,
                             use_updated_hybrid_scheme, op_denylist);
}

}  // namespace optimize
}  // namespace tflite