xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_params.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_
17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_
18 
19 #include <cstdint>
20 #include <limits>
21 #include <type_traits>
22 
23 #include "tensorflow/lite/kernels/internal/compatibility.h"
24 
25 namespace tflite {
26 
27 namespace cpu_backend_gemm {
28 
// Storage order of the elements of a matrix: either column-major or
// row-major.
enum class Order {
  kColMajor = 0,
  kRowMajor = 1,
};
31 
// Caching policy that can be attached to a matrix (see
// MatrixParams::cache_policy). The underlying type is a single byte.
enum class CachePolicy : std::uint8_t {
  kNeverCache = 0,
  kCacheIfLargeSpeedup = 1,
  kAlwaysCache = 2
};
37 
DefaultCachePolicy(bool is_constant_data)38 inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
39   return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
40                           : CachePolicy::kNeverCache;
41 }
42 
43 // MatrixParams encapsulates the parameters that Gemm needs about each
44 // matrix, besides the buffer data pointer.
45 // Compare to ruy::Matrix, which also encapsulates the data pointer.
46 // Rationale for leaving the data pointer out of here: doing so
47 // requires complicated const-correctness mechanics. See
48 // ruy::ConstCheckingPtr.
49 template <typename Scalar>
50 struct MatrixParams {
51   // Storage layout order. For now we only do plain linear non-strided
52   // layout. It would be easy to support a stride if needed.
53   Order order = Order::kColMajor;
54   // Number of rows of the matrix.
55   int rows = 0;
56   // Number of columns of the matrix.
57   int cols = 0;
58   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
59   // When Scalar is floating-point, this must be 0.
60   Scalar zero_point = 0;
61   // When the data pointed to by this matrix is constant data, so that it is
62   // valid to assume that equality of pointers implies equality of data,
63   // a CachePolicy may be used instead of the default kNeverCache,
64   // which will enable ruy to take advantage of this constancy of the data to
65   // cache the packing work, which can be a large speedup in matrix*vector
66   // and other narrow shapes.
67   CachePolicy cache_policy = CachePolicy::kNeverCache;
68 };
69 
// Broad categories of Gemm.
//
// This mainly exists so that Gemm can compile only the uniform-quantized or
// only the per-channel-quantized code paths. With ruy as the back-end this
// is unneeded, since there it is merely a runtime difference. With gemmlowp,
// however, these are genuinely separate code paths, and templatizing on a
// QuantizationFlavor is necessary to avoid compiling unused gemmlowp code.
// Indeed, TFLite currently uses uint8 with uniform quantization and int8
// with per-channel quantization, and does not use uint8 with per-channel.
// We want to avoid compiling the gemmlowp uint8 per-channel path when
// gemmlowp is the back-end.
//
// This may be dropped in the future if gemmlowp goes away and no other
// then-relevant back-end library handles quantized paths in a way that
// requires knowing this at compile time.
enum class QuantizationFlavor {
  // Floating-point Gemm: no 'multiplier' is applied to the accumulators.
  kFloatingPoint = 0,
  // Quantized Gemm where one single multiplier applies to all accumulators.
  kIntegerWithUniformMultiplier = 1,
  // Quantized Gemm where the accumulators of each row of the destination
  // matrix use a separate multiplier. GemmParams calls this 'per-channel'.
  // The more specific 'per-row' terminology is used here to leave room for
  // a future 'per-column' flavor, and to allow that to be a separate code
  // path in some back-end such as gemmlowp.
  kIntegerWithPerRowMultiplier = 2
};
99 
100 // Additional parameters that Gemm needs, beyond what falls into
101 // the MatrixParams that it takes. Compare to ruy::Spec.
102 //
103 // Decoupling AccumScalar from DstScalar (rather than deducing it from that)
104 // is useful future-proofing. Think of a float16 path using float32 accum.
105 //
106 // QuantizationFlavor is passed here even though it's technically not used
107 // in this class. This is so that we retain the ability in the future to
108 // specialize this class for quantization flavor, and this allows for
109 // Gemm to be templatized in quantization_flavor via the GemmParams that it
110 // takes, allowing for automatic template parameter deduction to take place,
111 // so that most call sites don't need to specify a QuantizationFlavor
112 // (only those that need perchannel quantization do).
113 template <typename AccumScalar, typename DstScalar,
114           QuantizationFlavor quantization_flavor =
115               std::is_floating_point<AccumScalar>::value
116                   ? QuantizationFlavor::kFloatingPoint
117                   : QuantizationFlavor::kIntegerWithUniformMultiplier>
118 struct GemmParams {
119   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
120   // of the multiplier by which accumulators are multiplied before being casted
121   // to the destination type.
122   AccumScalar multiplier_fixedpoint = 0;
123   // Only for non-floating-point cases. The exponent part of the aforementioned
124   // multiplier.
125   int multiplier_exponent = 0;
126   // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
127   // point to a buffer of as many values as there are rows in the destination
128   // matrix. Each row of the destination matrix will use the corresponding
129   // buffer element instead of multiplier_fixedpoint.
130   const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
131   // Per-channel variant of multiplier_exponent. If not nullptr, this must
132   // point to a buffer of as many values as there are rows in the destination
133   // matrix. Each row of the destination matrix will use the corresponding
134   // buffer element instead of multiplier_exponent.
135   //
136   // Either none or both of multiplier_exponent_perchannel and
137   // multiplier_fixedpoint_perchannel must be nullptr.
138   const int* multiplier_exponent_perchannel = nullptr;
139   // The bias vector data, if not null.
140   const AccumScalar* bias = nullptr;
141   // min clamp bound of destination values.
142   DstScalar clamp_min = std::is_floating_point<DstScalar>::value
143                             ? -std::numeric_limits<DstScalar>::infinity()
144                             : std::numeric_limits<DstScalar>::lowest();
145   // max clamp bound of destination values.
146   DstScalar clamp_max = std::is_floating_point<DstScalar>::value
147                             ? std::numeric_limits<DstScalar>::infinity()
148                             : std::numeric_limits<DstScalar>::max();
149 };
150 
151 /* Convenience typedefs */
152 
153 template <typename DstScalar>
154 using QuantizedGemmParams = GemmParams<std::int32_t, DstScalar>;
155 
156 using FloatGemmParams = GemmParams<float, float>;
157 
/* Validation functions */

// Note that this uses TFLITE_DCHECK from kernels/internal/compatibility.h
// and not TF_LITE_ASSERT from op_macros.h. We want these to be explicitly
// debug-build-only assertions so that there's no reason not to
// generously validate, and TF_LITE_ASSERT is actually at the moment
// a release-build assertion. See b/131587258.
165 
166 // Validates self-consistency of GemmParams.
167 template <typename AccumScalar, typename DstScalar,
168           QuantizationFlavor quantization_flavor>
ValidateGemmParams(const GemmParams<AccumScalar,DstScalar,quantization_flavor> & params)169 void ValidateGemmParams(
170     const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
171   // Guard consistency of the quantized multiplier fields.
172   if (quantization_flavor == QuantizationFlavor::kFloatingPoint) {
173     TFLITE_DCHECK(!params.multiplier_fixedpoint);
174     TFLITE_DCHECK(!params.multiplier_exponent);
175     TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
176     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
177   } else if (quantization_flavor ==
178                  QuantizationFlavor::kIntegerWithUniformMultiplier &&
179              !std::is_same<DstScalar, int32_t>::value) {
180     TFLITE_DCHECK(params.multiplier_fixedpoint);
181     // Nothing to check about multiplier_exponent
182     TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
183     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
184   } else if (quantization_flavor ==
185                  QuantizationFlavor::kIntegerWithPerRowMultiplier &&
186              !std::is_same<DstScalar, int32_t>::value) {
187     TFLITE_DCHECK(!params.multiplier_fixedpoint);
188     TFLITE_DCHECK(!params.multiplier_exponent);
189     TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
190     TFLITE_DCHECK(params.multiplier_exponent_perchannel);
191   } else {
192     // For the get raw accumulator case, we should make sure none of the
193     // quantization params are set.
194     TFLITE_DCHECK(!params.multiplier_fixedpoint);
195     TFLITE_DCHECK(!params.multiplier_exponent);
196     TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
197     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
198   }
199 }
200 
201 namespace detail {
202 
203 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
204           typename DstScalar, QuantizationFlavor quantization_flavor>
205 struct ValidateTypes {
206   // This generic implementation is for quantized flavors.
207   // kFloatingPoint will be a specialization below.
208   static_assert(!std::is_floating_point<LhsScalar>::value, "");
209   static_assert(!std::is_floating_point<RhsScalar>::value, "");
210   static_assert(!std::is_floating_point<AccumScalar>::value, "");
211   // No requirement on DstScalar --- we might in the future allow it
212   // to be floating point even in a quantized Gemm.
213 };
214 
215 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
216           typename DstScalar>
217 struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
218                      QuantizationFlavor::kFloatingPoint> {
219   static_assert(std::is_floating_point<LhsScalar>::value, "");
220   static_assert(std::is_floating_point<RhsScalar>::value, "");
221   static_assert(std::is_floating_point<AccumScalar>::value, "");
222   static_assert(std::is_floating_point<DstScalar>::value, "");
223 };
224 
225 }  // namespace detail
226 
227 // Validates overall consistency of all the parameters taken by a Gemm call:
228 // the 3 MatrixParams and the GemmParams.
229 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
230           typename DstScalar, QuantizationFlavor quantization_flavor>
231 void ValidateParams(
232     const MatrixParams<LhsScalar>& lhs_params,
233     const MatrixParams<RhsScalar>& rhs_params,
234     const MatrixParams<DstScalar>& dst_params,
235     const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
236   (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
237                               quantization_flavor>();
238   ValidateGemmParams(params);
239 }
240 
241 // Test if the Gemm is degenerate in some way, e.g. nonsensical dimenions.
242 template <typename LhsScalar, typename RhsScalar, typename DstScalar>
243 bool IsValidGemm(const MatrixParams<LhsScalar>& lhs_params,
244                  const MatrixParams<RhsScalar>& rhs_params,
245                  const MatrixParams<DstScalar>& dst_params) {
246   bool valid = true;
247   valid &= lhs_params.rows >= 1;
248   valid &= lhs_params.cols >= 1;
249   valid &= rhs_params.rows >= 1;
250   valid &= rhs_params.cols >= 1;
251   valid &= dst_params.rows >= 1;
252   valid &= dst_params.cols >= 1;
253   valid &= lhs_params.cols == rhs_params.rows;
254   valid &= rhs_params.cols == dst_params.cols;
255   valid &= lhs_params.rows == lhs_params.rows;
256   return valid;
257 }
258 
259 }  // namespace cpu_backend_gemm
260 
261 }  // namespace tflite
262 
263 #endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_
264