xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
17 
18 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
19 
20 using namespace mlir;
21 using namespace mlir::quantfork;
22 
getDefaultStorageParams(unsigned numBits,bool narrowRange,bool isSigned,MLIRContext * ctx,Type & storageType,int64_t & qmin,int64_t & qmax)23 static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
24                                     bool isSigned, MLIRContext *ctx,
25                                     Type &storageType, int64_t &qmin,
26                                     int64_t &qmax) {
27   // Hard-coded type mapping from TFLite.
28   if (numBits <= 8) {
29     storageType = IntegerType::get(ctx, 8);
30     if (isSigned) {
31       qmin = -128;
32       qmax = 127;
33     } else {
34       qmin = 0;
35       qmax = 255;
36     }
37   } else if (numBits <= 16) {
38     storageType = IntegerType::get(ctx, 16);
39     if (isSigned) {
40       qmin = -32768;
41       qmax = 32767;
42     } else {
43       qmin = 0;
44       qmax = 65535;
45     }
46   } else if (numBits <= 32) {
47     storageType = IntegerType::get(ctx, 32);
48     if (isSigned) {
49       qmin = std::numeric_limits<int32_t>::min();
50       qmax = std::numeric_limits<int32_t>::max();
51     } else {
52       qmin = std::numeric_limits<uint32_t>::min();
53       qmax = std::numeric_limits<uint32_t>::max();
54     }
55   } else {
56     return true;
57   }
58 
59   // Handle narrowRange.
60   if (narrowRange) {
61     qmin += 1;
62   }
63   return false;
64 }
65 
// Computes the scale and an integer ("nudged") zero point for the real range
// [rmin, rmax] mapped onto the quantized range [qmin, qmax].
//
// This is one specific nudging scheme: when the real range does not contain
// 0.0 (0.0 < rmin < rmax or rmin < rmax < 0.0), the range is effectively
// shifted so that it does, while its width (rmax - rmin) and therefore the
// scale stay the same. The zero point is derived from that shifted range. As a
// consequence, some values inside the original [rmin, rmax] fall outside the
// shifted range and get clamped during quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  const double lowQ = qmin;
  const double highQ = qmax;

  // Scale: real-valued width of one quantized step.
  scale = (rmax - rmin) / (highQ - lowQ);

  // The affine mapping can be solved for the zero point from either known
  // (real, quantized) pair: (rmin, qmin) or (rmax, qmax). The rounding error
  // of each candidate is roughly machine_epsilon times the sum of the absolute
  // values of its terms, so the candidate with the smaller sum is preferred.
  const double scaledMin = rmin / scale;
  const double scaledMax = rmax / scale;
  const double candidateFromMin = lowQ - scaledMin;
  const double candidateFromMax = highQ - scaledMax;
  const double errorFromMin = std::abs(lowQ) + std::abs(scaledMin);
  const double errorFromMax = std::abs(highQ) + std::abs(scaledMax);

  double realZeroPoint = candidateFromMax;
  if (errorFromMin < errorFromMax) {
    realZeroPoint = candidateFromMin;
  }

  // Nudge to an integer: clamp candidates outside [qmin, qmax] to the nearest
  // bound, otherwise round to the nearest integer.
  const bool belowRange = realZeroPoint < lowQ;
  const bool aboveRange = realZeroPoint > highQ;
  if (belowRange) {
    nudgedZeroPoint = qmin;
  } else if (aboveRange) {
    nudgedZeroPoint = qmax;
  } else {
    nudgedZeroPoint = round(realZeroPoint);
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
114 
fakeQuantAttrsToType(Location loc,unsigned numBits,double rmin,double rmax,bool narrowRange,Type expressedType,bool isSigned)115 quant::UniformQuantizedType mlir::quantfork::fakeQuantAttrsToType(
116     Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange,
117     Type expressedType, bool isSigned) {
118   MLIRContext *ctx = expressedType.getContext();
119   unsigned flags = isSigned ? quant::QuantizationFlags::Signed : 0;
120   Type storageType;
121   int64_t qmin;
122   int64_t qmax;
123   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
124                               qmin, qmax)) {
125     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
126             nullptr);
127   }
128 
129   // Special case where min/max is close enough. The tensor contents are all
130   // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
131   // points and dequantized to 0.0.
132   if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
133     return quant::UniformQuantizedType::getChecked(
134         loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
135   }
136 
137   double scale;
138   int64_t nudgedZeroPoint;
139   getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
140 
141   return quant::UniformQuantizedType::getChecked(loc, flags, storageType,
142                                                  expressedType, scale,
143                                                  nudgedZeroPoint, qmin, qmax);
144 }
145 
fakeQuantAttrsToType(Location loc,unsigned numBits,int32_t quantizedDimension,ArrayRef<double> rmins,ArrayRef<double> rmaxs,bool narrowRange,Type expressedType,bool isSigned)146 quant::UniformQuantizedPerAxisType mlir::quantfork::fakeQuantAttrsToType(
147     Location loc, unsigned numBits, int32_t quantizedDimension,
148     ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
149     Type expressedType, bool isSigned) {
150   size_t axisSize = rmins.size();
151   if (axisSize != rmaxs.size()) {
152     return (emitError(loc, "mismatched per-axis min and max size: ")
153                 << axisSize << " vs. " << rmaxs.size(),
154             nullptr);
155   }
156 
157   MLIRContext *ctx = expressedType.getContext();
158   Type storageType;
159   int64_t qmin;
160   int64_t qmax;
161   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
162                               qmin, qmax)) {
163     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
164             nullptr);
165   }
166 
167   SmallVector<double, 4> scales;
168   SmallVector<int64_t, 4> zeroPoints;
169   scales.reserve(axisSize);
170   zeroPoints.reserve(axisSize);
171   for (size_t axis = 0; axis != axisSize; ++axis) {
172     double rmin = rmins[axis];
173     double rmax = rmaxs[axis];
174     if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
175       scales.push_back(1.0);
176       zeroPoints.push_back(qmin);
177       continue;
178     }
179 
180     double scale;
181     int64_t nudgedZeroPoint;
182     getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
183     scales.push_back(scale);
184     zeroPoints.push_back(nudgedZeroPoint);
185   }
186 
187   unsigned flags = isSigned ? quant::QuantizationFlags::Signed : 0;
188   return quant::UniformQuantizedPerAxisType::getChecked(
189       loc, flags, storageType, expressedType, scales, zeroPoints,
190       quantizedDimension, qmin, qmax);
191 }
192