1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/graph_transformations/quantization_util.h"
16
17 #include <memory>
18 #include <string>
19
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
22 #include "tensorflow/lite/toco/model.h"
23 #include "tensorflow/lite/toco/tooling_util.h"
24
25 namespace toco {
26
InferQuantizedDataTypeFromFakeQuant(const FakeQuantOperator & op,ArrayDataType * out_quantized_data_type)27 bool InferQuantizedDataTypeFromFakeQuant(
28 const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type) {
29 if (op.num_bits <= 8) {
30 *out_quantized_data_type = ArrayDataType::kUint8;
31 return true;
32 } else if (op.num_bits <= 16) {
33 *out_quantized_data_type = ArrayDataType::kInt16;
34 return true;
35 } else {
36 *out_quantized_data_type = ArrayDataType::kNone;
37 return false;
38 }
39 }
40
GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,double * out_min_value,double * out_max_value)41 bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
42 double* out_min_value,
43 double* out_max_value) {
44 switch (data_type) {
45 case ArrayDataType::kUint8:
46 *out_min_value = 0;
47 *out_max_value = 255;
48 return true;
49 case ArrayDataType::kInt16:
50 *out_min_value = -32768;
51 *out_max_value = 32767;
52 return true;
53 default:
54 return false;
55 }
56 }
57
GetQuantizedDataType(const Array & array,ArrayDataType default_type)58 ArrayDataType GetQuantizedDataType(const Array& array,
59 ArrayDataType default_type) {
60 switch (array.final_data_type) {
61 case ArrayDataType::kInt8:
62 case ArrayDataType::kUint8:
63 case ArrayDataType::kInt16:
64 case ArrayDataType::kUint16:
65 case ArrayDataType::kInt32:
66 case ArrayDataType::kUint32:
67 case ArrayDataType::kInt64:
68 case ArrayDataType::kUint64:
69 return array.final_data_type;
70 case ArrayDataType::kFloat:
71 case ArrayDataType::kNone:
72 return default_type;
73 default:
74 LOG(FATAL) << "Unhandled final quantization type "
75 << static_cast<int>(array.final_data_type);
76 }
77 }
78
79 template <ArrayDataType A>
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,QuantizationParams * quantization_params)80 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
81 const Array& array, QuantizationParams* quantization_params) {
82 *quantization_params = ::tflite::ChooseQuantizationParams<DataType<A>>(
83 array.minmax->min, array.minmax->max, array.narrow_range);
84 }
85
ChooseQuantizationParamsForArrayAndQuantizedDataType(const Array & array,ArrayDataType quantized_data_type,QuantizationParams * quantization_params)86 void ChooseQuantizationParamsForArrayAndQuantizedDataType(
87 const Array& array, ArrayDataType quantized_data_type,
88 QuantizationParams* quantization_params) {
89 switch (quantized_data_type) {
90 case ArrayDataType::kInt8:
91 ChooseQuantizationParamsForArrayAndQuantizedDataType<
92 ArrayDataType::kInt8>(array, quantization_params);
93 break;
94 case ArrayDataType::kUint8:
95 ChooseQuantizationParamsForArrayAndQuantizedDataType<
96 ArrayDataType::kUint8>(array, quantization_params);
97 break;
98 case ArrayDataType::kInt16:
99 ChooseQuantizationParamsForArrayAndQuantizedDataType<
100 ArrayDataType::kInt16>(array, quantization_params);
101 break;
102 case ArrayDataType::kUint16:
103 ChooseQuantizationParamsForArrayAndQuantizedDataType<
104 ArrayDataType::kUint16>(array, quantization_params);
105 break;
106 case ArrayDataType::kInt32:
107 ChooseQuantizationParamsForArrayAndQuantizedDataType<
108 ArrayDataType::kInt32>(array, quantization_params);
109 break;
110 case ArrayDataType::kUint32:
111 ChooseQuantizationParamsForArrayAndQuantizedDataType<
112 ArrayDataType::kUint32>(array, quantization_params);
113 break;
114 case ArrayDataType::kInt64:
115 ChooseQuantizationParamsForArrayAndQuantizedDataType<
116 ArrayDataType::kInt64>(array, quantization_params);
117 break;
118 case ArrayDataType::kUint64:
119 ChooseQuantizationParamsForArrayAndQuantizedDataType<
120 ArrayDataType::kUint64>(array, quantization_params);
121 break;
122 case ArrayDataType::kFloat:
123 case ArrayDataType::kComplex64:
124 case ArrayDataType::kNone:
125 default:
126 LOG(FATAL) << "Unhandled final quantization type "
127 << static_cast<int>(quantized_data_type);
128 }
129 }
130
131 namespace {
132
133 template <ArrayDataType A>
QuantizeBuffer(const Array & array,const QuantizationParams & quantization_params)134 std::unique_ptr<GenericBuffer> QuantizeBuffer(
135 const Array& array, const QuantizationParams& quantization_params) {
136 const GenericBuffer& buffer = *array.buffer;
137 const auto inverse_scale = 1. / quantization_params.scale;
138 CHECK(buffer.type == ArrayDataType::kFloat);
139 const auto& float_buffer =
140 static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
141 auto* quantized_buffer = new Buffer<A>;
142 quantized_buffer->data.resize(float_buffer.data.size());
143 for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
144 const float src_val = float_buffer.data[i];
145 double scaled_val; // Astonishingly, using 'float' degrades accuracy just
146 // enough to make a few tests fail!
147 if (quantization_params.scale == 0) {
148 CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
149 << "so all its values should be 0.";
150 scaled_val = quantization_params.zero_point;
151 } else {
152 scaled_val = quantization_params.zero_point + inverse_scale * src_val;
153 }
154 auto integer_val = tflite::SafeCast<DataType<A>>(std::round(scaled_val));
155 // In addition to its effect on the choice of quantization params upstream
156 // of here, narrow_range also means nudge the min quantized value by +1,
157 // so e.g. uint8 values get constrained to [1, 255].
158 if (integer_val == std::numeric_limits<DataType<A>>::min() &&
159 array.narrow_range) {
160 integer_val++;
161 }
162 quantized_buffer->data[i] = integer_val;
163 }
164 return std::unique_ptr<GenericBuffer>(quantized_buffer);
165 }
166
167 template <ArrayDataType A>
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,const QuantizationParams & quantization_params)168 void QuantizeArray(GraphTransformation* transformation, Model* model,
169 const std::string& name,
170 const QuantizationParams& quantization_params) {
171 auto& array = model->GetArray(name);
172 CHECK(array.data_type == ArrayDataType::kFloat);
173 CHECK(!array.quantization_params);
174 array.GetOrCreateQuantizationParams() = quantization_params;
175 if (array.buffer) {
176 array.buffer = QuantizeBuffer<A>(array, quantization_params);
177 }
178 array.data_type = A;
179 array.final_data_type = A;
180 transformation->AddMessageF(
181 "Quantized array %s to %s zero_point=%g, scale=%g", name,
182 ArrayDataTypeName(array.data_type), quantization_params.zero_point,
183 quantization_params.scale);
184 }
185
186 } // namespace
187
QuantizeArray(GraphTransformation * transformation,Model * model,const std::string & name,ArrayDataType quantized_data_type,const QuantizationParams & quantization_params)188 void QuantizeArray(GraphTransformation* transformation, Model* model,
189 const std::string& name, ArrayDataType quantized_data_type,
190 const QuantizationParams& quantization_params) {
191 ArrayDataType adjusted_data_type = quantized_data_type;
192 auto& array = model->GetArray(name);
193 if (array.final_data_type == ArrayDataType::kInt16) {
194 adjusted_data_type = array.final_data_type;
195 }
196
197 switch (adjusted_data_type) {
198 case ArrayDataType::kUint8:
199 return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
200 quantization_params);
201 case ArrayDataType::kInt16:
202 return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
203 quantization_params);
204 case ArrayDataType::kInt32:
205 return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
206 quantization_params);
207 default:
208 LOG(FATAL) << "Unhandled case.";
209 }
210 }
211
IsArrayQuantizedRangeSubset(GraphTransformation * transformation,const Array & array,double clamp_min,double clamp_max)212 bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
213 const Array& array, double clamp_min,
214 double clamp_max) {
215 ArrayDataType quantized_data_type =
216 GetQuantizedDataType(array, array.data_type);
217 if (quantized_data_type == ArrayDataType::kNone ||
218 quantized_data_type == ArrayDataType::kFloat) {
219 // The array is not (or never will be) quantized.
220 return false;
221 }
222
223 QuantizationParams quantization_params;
224 if (!array.quantization_params) {
225 if (!array.minmax) {
226 transformation->AddMessageF("No quantization params and no minmax");
227 return false;
228 } else {
229 // Work around cases where we are asking for this prior to the Quantize
230 // transformation having added the quantization_params.
231 ChooseQuantizationParamsForArrayAndQuantizedDataType(
232 array, quantized_data_type, &quantization_params);
233 transformation->AddMessageF(
234 "No quantization params - inferring from data type %s with minmax "
235 "%g,%g as zero_point=%g, scale=%g",
236 ArrayDataTypeName(quantized_data_type), array.minmax->min,
237 array.minmax->max, quantization_params.zero_point,
238 quantization_params.scale);
239 }
240 } else {
241 quantization_params = array.GetQuantizationParams();
242 }
243
244 double quantized_min, quantized_max;
245 CHECK(GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
246 &quantized_max))
247 << "Type is not quantized";
248
249 bool has_nontrivial_min_bound = false;
250 bool has_nontrivial_max_bound = false;
251
252 double lowest_representable_output =
253 (quantized_min - quantization_params.zero_point) *
254 quantization_params.scale;
255 if (lowest_representable_output < clamp_min) {
256 has_nontrivial_min_bound = true;
257 transformation->AddMessageF(
258 "Quantized activation function is not trivial: "
259 "the lowest representable output value %g"
260 " less than the clamp min bound %g.",
261 lowest_representable_output, clamp_min);
262 }
263
264 double highest_representable_output =
265 (quantized_max - quantization_params.zero_point) *
266 quantization_params.scale;
267 if (highest_representable_output > clamp_max) {
268 has_nontrivial_max_bound = true;
269 transformation->AddMessageF(
270 "Quantized activation function is not trivial: "
271 "the highest representable output value %g"
272 " is greater than the clamp max bound %g.",
273 highest_representable_output, clamp_max);
274 }
275
276 return !has_nontrivial_min_bound && !has_nontrivial_max_bound;
277 }
278
279 } // namespace toco
280