xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/graph_transformations/quantization_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
16 
17 #include <memory>
18 #include <string>
19 
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
22 #include "tensorflow/lite/toco/model.h"
23 #include "tensorflow/lite/toco/tooling_util.h"
24 
25 namespace toco {
26 
InferQuantizedDataTypeFromFakeQuant(const FakeQuantOperator & op,ArrayDataType * out_quantized_data_type)27 bool InferQuantizedDataTypeFromFakeQuant(
28     const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type) {
29   if (op.num_bits <= 8) {
30     *out_quantized_data_type = ArrayDataType::kUint8;
31     return true;
32   } else if (op.num_bits <= 16) {
33     *out_quantized_data_type = ArrayDataType::kInt16;
34     return true;
35   } else {
36     *out_quantized_data_type = ArrayDataType::kNone;
37     return false;
38   }
39 }
40 
GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,double * out_min_value,double * out_max_value)41 bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
42                                         double* out_min_value,
43                                         double* out_max_value) {
44   switch (data_type) {
45     case ArrayDataType::kUint8:
46       *out_min_value = 0;
47       *out_max_value = 255;
48       return true;
49     case ArrayDataType::kInt16:
50       *out_min_value = -32768;
51       *out_max_value = 32767;
52       return true;
53     default:
54       return false;
55   }
56 }
57 
GetQuantizedDataType(const Array & array,ArrayDataType default_type)58 ArrayDataType GetQuantizedDataType(const Array& array,
59                                    ArrayDataType default_type) {
60   switch (array.final_data_type) {
61     case ArrayDataType::kInt8:
62     case ArrayDataType::kUint8:
63     case ArrayDataType::kInt16:
64     case ArrayDataType::kUint16:
65     case ArrayDataType::kInt32:
66     case ArrayDataType::kUint32:
67     case ArrayDataType::kInt64:
68     case ArrayDataType::kUint64:
69       return array.final_data_type;
70     case ArrayDataType::kFloat:
71     case ArrayDataType::kNone:
72       return default_type;
73     default:
74       LOG(FATAL) << "Unhandled final quantization type "
75                  << static_cast<int>(array.final_data_type);
76   }
77 }
78 
79 template <ArrayDataType A>
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,QuantizationParams * quantization_params)80 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
81     const Array& array, QuantizationParams* quantization_params) {
82   *quantization_params = ::tflite::ChooseQuantizationParams<DataType<A>>(
83       array.minmax->min, array.minmax->max, array.narrow_range);
84 }
85 
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,ArrayDataType quantized_data_type,QuantizationParams * quantization_params)86 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
87     const Array& array, ArrayDataType quantized_data_type,
88     QuantizationParams* quantization_params) {
89   switch (quantized_data_type) {
90     case ArrayDataType::kInt8:
91       ChooseQuantizationParamsForArrayAndQuantizedDataType<
92           ArrayDataType::kInt8>(array, quantization_params);
93       break;
94     case ArrayDataType::kUint8:
95       ChooseQuantizationParamsForArrayAndQuantizedDataType<
96           ArrayDataType::kUint8>(array, quantization_params);
97       break;
98     case ArrayDataType::kInt16:
99       ChooseQuantizationParamsForArrayAndQuantizedDataType<
100           ArrayDataType::kInt16>(array, quantization_params);
101       break;
102     case ArrayDataType::kUint16:
103       ChooseQuantizationParamsForArrayAndQuantizedDataType<
104           ArrayDataType::kUint16>(array, quantization_params);
105       break;
106     case ArrayDataType::kInt32:
107       ChooseQuantizationParamsForArrayAndQuantizedDataType<
108           ArrayDataType::kInt32>(array, quantization_params);
109       break;
110     case ArrayDataType::kUint32:
111       ChooseQuantizationParamsForArrayAndQuantizedDataType<
112           ArrayDataType::kUint32>(array, quantization_params);
113       break;
114     case ArrayDataType::kInt64:
115       ChooseQuantizationParamsForArrayAndQuantizedDataType<
116           ArrayDataType::kInt64>(array, quantization_params);
117       break;
118     case ArrayDataType::kUint64:
119       ChooseQuantizationParamsForArrayAndQuantizedDataType<
120           ArrayDataType::kUint64>(array, quantization_params);
121       break;
122     case ArrayDataType::kFloat:
123     case ArrayDataType::kComplex64:
124     case ArrayDataType::kNone:
125     default:
126       LOG(FATAL) << "Unhandled final quantization type "
127                  << static_cast<int>(quantized_data_type);
128   }
129 }
130 
131 namespace {
132 
133 template <ArrayDataType A>
QuantizeBuffer(const Array & array,const QuantizationParams & quantization_params)134 std::unique_ptr<GenericBuffer> QuantizeBuffer(
135     const Array& array, const QuantizationParams& quantization_params) {
136   const GenericBuffer& buffer = *array.buffer;
137   const auto inverse_scale = 1. / quantization_params.scale;
138   CHECK(buffer.type == ArrayDataType::kFloat);
139   const auto& float_buffer =
140       static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
141   auto* quantized_buffer = new Buffer<A>;
142   quantized_buffer->data.resize(float_buffer.data.size());
143   for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
144     const float src_val = float_buffer.data[i];
145     double scaled_val;  // Astonishingly, using 'float' degrades accuracy just
146                         // enough to make a few tests fail!
147     if (quantization_params.scale == 0) {
148       CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
149                            << "so all its values should be 0.";
150       scaled_val = quantization_params.zero_point;
151     } else {
152       scaled_val = quantization_params.zero_point + inverse_scale * src_val;
153     }
154     auto integer_val = tflite::SafeCast<DataType<A>>(std::round(scaled_val));
155     // In addition to its effect on the choice of quantization params upstream
156     // of here, narrow_range also means nudge the min quantized value by +1,
157     // so e.g. uint8 values get constrained to [1, 255].
158     if (integer_val == std::numeric_limits<DataType<A>>::min() &&
159         array.narrow_range) {
160       integer_val++;
161     }
162     quantized_buffer->data[i] = integer_val;
163   }
164   return std::unique_ptr<GenericBuffer>(quantized_buffer);
165 }
166 
167 template <ArrayDataType A>
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,const QuantizationParams & quantization_params)168 void QuantizeArray(GraphTransformation* transformation, Model* model,
169                    const std::string& name,
170                    const QuantizationParams& quantization_params) {
171   auto& array = model->GetArray(name);
172   CHECK(array.data_type == ArrayDataType::kFloat);
173   CHECK(!array.quantization_params);
174   array.GetOrCreateQuantizationParams() = quantization_params;
175   if (array.buffer) {
176     array.buffer = QuantizeBuffer<A>(array, quantization_params);
177   }
178   array.data_type = A;
179   array.final_data_type = A;
180   transformation->AddMessageF(
181       "Quantized array %s to %s zero_point=%g, scale=%g", name,
182       ArrayDataTypeName(array.data_type), quantization_params.zero_point,
183       quantization_params.scale);
184 }
185 
186 }  // namespace
187 
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,ArrayDataType quantized_data_type,const QuantizationParams & quantization_params)188 void QuantizeArray(GraphTransformation* transformation, Model* model,
189                    const std::string& name, ArrayDataType quantized_data_type,
190                    const QuantizationParams& quantization_params) {
191   ArrayDataType adjusted_data_type = quantized_data_type;
192   auto& array = model->GetArray(name);
193   if (array.final_data_type == ArrayDataType::kInt16) {
194     adjusted_data_type = array.final_data_type;
195   }
196 
197   switch (adjusted_data_type) {
198     case ArrayDataType::kUint8:
199       return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
200                                                   quantization_params);
201     case ArrayDataType::kInt16:
202       return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
203                                                   quantization_params);
204     case ArrayDataType::kInt32:
205       return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
206                                                   quantization_params);
207     default:
208       LOG(FATAL) << "Unhandled case.";
209   }
210 }
211 
IsArrayQuantizedRangeSubset(GraphTransformation * transformation,const Array & array,double clamp_min,double clamp_max)212 bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
213                                  const Array& array, double clamp_min,
214                                  double clamp_max) {
215   ArrayDataType quantized_data_type =
216       GetQuantizedDataType(array, array.data_type);
217   if (quantized_data_type == ArrayDataType::kNone ||
218       quantized_data_type == ArrayDataType::kFloat) {
219     // The array is not (or never will be) quantized.
220     return false;
221   }
222 
223   QuantizationParams quantization_params;
224   if (!array.quantization_params) {
225     if (!array.minmax) {
226       transformation->AddMessageF("No quantization params and no minmax");
227       return false;
228     } else {
229       // Work around cases where we are asking for this prior to the Quantize
230       // transformation having added the quantization_params.
231       ChooseQuantizationParamsForArrayAndQuantizedDataType(
232           array, quantized_data_type, &quantization_params);
233       transformation->AddMessageF(
234           "No quantization params - inferring from data type %s with minmax "
235           "%g,%g as zero_point=%g, scale=%g",
236           ArrayDataTypeName(quantized_data_type), array.minmax->min,
237           array.minmax->max, quantization_params.zero_point,
238           quantization_params.scale);
239     }
240   } else {
241     quantization_params = array.GetQuantizationParams();
242   }
243 
244   double quantized_min, quantized_max;
245   CHECK(GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
246                                            &quantized_max))
247       << "Type is not quantized";
248 
249   bool has_nontrivial_min_bound = false;
250   bool has_nontrivial_max_bound = false;
251 
252   double lowest_representable_output =
253       (quantized_min - quantization_params.zero_point) *
254       quantization_params.scale;
255   if (lowest_representable_output < clamp_min) {
256     has_nontrivial_min_bound = true;
257     transformation->AddMessageF(
258         "Quantized activation function is not trivial: "
259         "the lowest representable output value %g"
260         " less than the clamp min bound %g.",
261         lowest_representable_output, clamp_min);
262   }
263 
264   double highest_representable_output =
265       (quantized_max - quantization_params.zero_point) *
266       quantization_params.scale;
267   if (highest_representable_output > clamp_max) {
268     has_nontrivial_max_bound = true;
269     transformation->AddMessageF(
270         "Quantized activation function is not trivial: "
271         "the highest representable output value %g"
272         " is greater than the clamp max bound %g.",
273         highest_representable_output, clamp_max);
274   }
275 
276   return !has_nontrivial_min_bound && !has_nontrivial_max_bound;
277 }
278 
279 }  // namespace toco
280