/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
17
18 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
19
20 using namespace mlir;
21 using namespace mlir::quantfork;
22
getDefaultStorageParams(unsigned numBits,bool narrowRange,bool isSigned,MLIRContext * ctx,Type & storageType,int64_t & qmin,int64_t & qmax)23 static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
24 bool isSigned, MLIRContext *ctx,
25 Type &storageType, int64_t &qmin,
26 int64_t &qmax) {
27 // Hard-coded type mapping from TFLite.
28 if (numBits <= 8) {
29 storageType = IntegerType::get(ctx, 8);
30 if (isSigned) {
31 qmin = -128;
32 qmax = 127;
33 } else {
34 qmin = 0;
35 qmax = 255;
36 }
37 } else if (numBits <= 16) {
38 storageType = IntegerType::get(ctx, 16);
39 if (isSigned) {
40 qmin = -32768;
41 qmax = 32767;
42 } else {
43 qmin = 0;
44 qmax = 65535;
45 }
46 } else if (numBits <= 32) {
47 storageType = IntegerType::get(ctx, 32);
48 if (isSigned) {
49 qmin = std::numeric_limits<int32_t>::min();
50 qmax = std::numeric_limits<int32_t>::max();
51 } else {
52 qmin = std::numeric_limits<uint32_t>::min();
53 qmax = std::numeric_limits<uint32_t>::max();
54 }
55 } else {
56 return true;
57 }
58
59 // Handle narrowRange.
60 if (narrowRange) {
61 qmin += 1;
62 }
63 return false;
64 }
65
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence some values, which are supposed in the original [rmin, rmax]
// range will be outside the shifted range and be clamped during quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  const double qminDouble = static_cast<double>(qmin);
  const double qmaxDouble = static_cast<double>(qmax);

  // The scale maps the full real range onto the full quantized range.
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Solve the affine equation zero_point = q - r / scale at both known
  // (real, quantized) pairs, (rmin, qmin) and (rmax, qmax). The arithmetic
  // error of each candidate is roughly machine_epsilon times the sum of the
  // absolute values of its terms, so keep whichever candidate is cheaper.
  const double zpFromMin = qminDouble - rmin / scale;
  const double zpFromMax = qmaxDouble - rmax / scale;
  const double errFromMin = std::abs(qminDouble) + std::abs(rmin / scale);
  const double errFromMax = std::abs(qmaxDouble) + std::abs(rmax / scale);
  const double zeroPointDouble =
      errFromMin < errFromMax ? zpFromMin : zpFromMax;

  // Nudge the zero point to an integer, clamping into [qmin, qmax].
  if (zeroPointDouble < qminDouble) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    nudgedZeroPoint = static_cast<int64_t>(round(zeroPointDouble));
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
114
fakeQuantAttrsToType(Location loc,unsigned numBits,double rmin,double rmax,bool narrowRange,Type expressedType,bool isSigned)115 quant::UniformQuantizedType mlir::quantfork::fakeQuantAttrsToType(
116 Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange,
117 Type expressedType, bool isSigned) {
118 MLIRContext *ctx = expressedType.getContext();
119 unsigned flags = isSigned ? quant::QuantizationFlags::Signed : 0;
120 Type storageType;
121 int64_t qmin;
122 int64_t qmax;
123 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
124 qmin, qmax)) {
125 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
126 nullptr);
127 }
128
129 // Special case where min/max is close enough. The tensor contents are all
130 // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
131 // points and dequantized to 0.0.
132 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
133 return quant::UniformQuantizedType::getChecked(
134 loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
135 }
136
137 double scale;
138 int64_t nudgedZeroPoint;
139 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
140
141 return quant::UniformQuantizedType::getChecked(loc, flags, storageType,
142 expressedType, scale,
143 nudgedZeroPoint, qmin, qmax);
144 }
145
fakeQuantAttrsToType(Location loc,unsigned numBits,int32_t quantizedDimension,ArrayRef<double> rmins,ArrayRef<double> rmaxs,bool narrowRange,Type expressedType,bool isSigned)146 quant::UniformQuantizedPerAxisType mlir::quantfork::fakeQuantAttrsToType(
147 Location loc, unsigned numBits, int32_t quantizedDimension,
148 ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
149 Type expressedType, bool isSigned) {
150 size_t axisSize = rmins.size();
151 if (axisSize != rmaxs.size()) {
152 return (emitError(loc, "mismatched per-axis min and max size: ")
153 << axisSize << " vs. " << rmaxs.size(),
154 nullptr);
155 }
156
157 MLIRContext *ctx = expressedType.getContext();
158 Type storageType;
159 int64_t qmin;
160 int64_t qmax;
161 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
162 qmin, qmax)) {
163 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
164 nullptr);
165 }
166
167 SmallVector<double, 4> scales;
168 SmallVector<int64_t, 4> zeroPoints;
169 scales.reserve(axisSize);
170 zeroPoints.reserve(axisSize);
171 for (size_t axis = 0; axis != axisSize; ++axis) {
172 double rmin = rmins[axis];
173 double rmax = rmaxs[axis];
174 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
175 scales.push_back(1.0);
176 zeroPoints.push_back(qmin);
177 continue;
178 }
179
180 double scale;
181 int64_t nudgedZeroPoint;
182 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
183 scales.push_back(scale);
184 zeroPoints.push_back(nudgedZeroPoint);
185 }
186
187 unsigned flags = isSigned ? quant::QuantizationFlags::Signed : 0;
188 return quant::UniformQuantizedPerAxisType::getChecked(
189 loc, flags, storageType, expressedType, scales, zeroPoints,
190 quantizedDimension, qmin, qmax);
191 }
192