/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_

#include <cstdint>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

#ifndef TFLITE_WITH_RUY
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_x86.h"
#endif

namespace tflite {

namespace cpu_backend_gemm {

// The main entry point for CpuBackendGemm::Gemm.
//
// If TFLITE_WITH_RUY is set, CpuBackendGemm::Gemm will always go to Ruy aka
// GemmImplUsingRuy. Other cases are as follows:
//
//                    |Quantized (uint8)|Quantized (int8)| Float |
// TFLITE_WITH_RUY    |      Ruy        |      Ruy       | Ruy   |
// !TFLITE_WITH_RUY   |      gemmlowp   |  Ruy/gemmlowp* | eigen |
// * - Ruy if NEON is not available.
//
// On x86 platforms:
//
//                            |Quantized (uint8)|Quantized (int8)| Float |
// (default)                  |     gemmlowp    |      Ruy       | eigen |
// TFLITE_X86_RUY_ENABLED &&
// (AVX or above available)   |       Ruy       |      Ruy       | Ruy   |

#if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
/* GEMM dispatch implementation for x86.
 */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplX86<LhsScalar, RhsScalar, AccumScalar,
                                      DstScalar, quantization_flavor> {};
#else
/* Generic implementation using ruy.
 * Non-ruy implementations are provided as partial specializations of this
 * template.
 */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
                                           DstScalar, quantization_flavor> {};

#if !defined(TFLITE_WITH_RUY)

/* Specializations using gemmlowp */
template <typename SrcScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
                                    DstScalar, quantization_flavor> {};

// When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
// outside of NEON. We avoid the compilation failure by subspecializing these
// cases, rerouting them back to ruy.
#if !defined(GEMMLOWP_NEON)
template <typename SrcScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                               quantization_flavor> {};

template <typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               DstScalar, quantization_flavor> {};

template <QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               std::int8_t, quantization_flavor> {};
#endif  // not GEMMLOWP_NEON

/* Specializations using Eigen */

template <>
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
    : detail::GemmImplUsingEigen {};

#endif  // not TFLITE_WITH_RUY

#endif  // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM

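// Note on the dispatch contract: each GemmImpl specialization above (and the
// detail::GemmImpl* backends they derive from) is expected to expose a static
// Run() taking the same argument list as the Gemm() entry points below. A
// direct dispatch would look like this sketch (hypothetical float arguments):
//
//   GemmImpl<float, float, float, float,
//            QuantizationFlavor::kFloatingPoint>::Run(
//       lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//       params, context);
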
/* Public entry point */

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  if (!IsValidGemm(lhs_params, rhs_params, dst_params)) {
    // For now, assert in debug mode, return in opt.
    // TODO(b/183099395) Eliminate debug/release discrepancy by plumbing in
    // TFLiteStatus so we can return an error here.
    TFLITE_DCHECK(false);
    return;
  }
  // In some cases we want to unconditionally use ruy as the backend,
  // overriding the `tflite_with_ruy` setting and the platform default.
  bool must_use_ruy = false;
  if (context->use_caching()) {
    // Only ruy supports caching of pre-packed matrices. Due to the large
    // performance impact in the cases where it's typically used, this
    // overrides the default.
    must_use_ruy = true;
  }
  if (lhs_params.order != Order::kRowMajor ||
      rhs_params.order != Order::kColMajor ||
      dst_params.order != Order::kColMajor) {
    // ruy supports all 2^3=8 combinations of storage orders with comparable
    // performance. In ruy, it's only a runtime switch. In other backends
    // (gemmlowp, Eigen), storage orders are template parameters, so supporting
    // all 8 combinations would mean up to an 8-fold code size increase. We
    // therefore prefer to force usage of ruy in these cases.
    must_use_ruy = true;
  }
  if (must_use_ruy) {
    detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                             quantization_flavor>::Run(lhs_params, lhs_data,
                                                       rhs_params, rhs_data,
                                                       dst_params, dst_data,
                                                       params, context);
    return;
  }
  // If we did not choose to force usage of ruy above, then we may now consider
  // using custom GEMV code for the matrix*vector cases.
  const bool try_custom_gemv = (dst_params.cols == 1);
  if (try_custom_gemv) {
    // GEMV case: try a custom fast GEMV path. It returns true if it actually
    // handled the computation.
    if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
                           dst_params, dst_data, params, context)) {
      return;
    }
  }
  // Generic case: dispatch to any backend as a general GEMM.
  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                     dst_params, dst_data, params, context);
}
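
// Example usage of the entry point above: a minimal float sketch, assuming
// the caller provides M, K, N, the data pointers and a CpuBackendContext*.
//
//   MatrixParams<float> lhs_params;
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = M;
//   lhs_params.cols = K;
//   MatrixParams<float> rhs_params;
//   rhs_params.order = Order::kColMajor;
//   rhs_params.rows = K;
//   rhs_params.cols = N;
//   MatrixParams<float> dst_params;
//   dst_params.order = Order::kColMajor;
//   dst_params.rows = M;
//   dst_params.cols = N;
//   GemmParams<float, float> gemm_params;  // kFloatingPoint flavor.
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, context);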

// Special path for 16x8 quantized GEMM.
template <QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<int8_t>& lhs_params, const int8_t* lhs_data,
          const MatrixParams<int16_t>& rhs_params, const int16_t* rhs_data,
          const MatrixParams<int16_t>& dst_params, int16_t* dst_data,
          const GemmParams<int32_t, int16_t, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  if (!IsValidGemm(lhs_params, rhs_params, dst_params)) {
    TFLITE_DCHECK(false);
    return;
  }

  // Currently, only the Ruy backend supports 16x8 quantized GEMM, so we
  // always use Ruy here.
  detail::GemmImplUsingRuy<int8_t, int16_t, int32_t, int16_t,
                           quantization_flavor>::Run(lhs_params, lhs_data,
                                                     rhs_params, rhs_data,
                                                     dst_params, dst_data,
                                                     params, context);
}
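
// A minimal 16x8 call sketch, assuming hypothetical shapes and data pointers:
// the LHS is int8 (typically weights), the RHS and destination are int16.
//
//   MatrixParams<int8_t> lhs_params;    // Order::kRowMajor, M x K
//   MatrixParams<int16_t> rhs_params;   // Order::kColMajor, K x N
//   MatrixParams<int16_t> dst_params;   // Order::kColMajor, M x N
//   GemmParams<int32_t, int16_t> gemm_params;
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, context);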

// Special path for GEMM with raw accumulators, i.e. the
// AccumScalar == DstScalar == int32 case.
template <typename LhsScalar, typename RhsScalar,
          QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<int32_t>& dst_params, int32_t* dst_data,
          const GemmParams<int32_t, int32_t, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);

  // Currently, only the Ruy backend supports returning raw accumulators, so
  // we always use Ruy here.
  ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
  detail::GemmImplUsingRuy<LhsScalar, RhsScalar, int32_t, int32_t,
                           quantization_flavor>::Run(lhs_params, lhs_data,
                                                     rhs_params, rhs_data,
                                                     dst_params, dst_data,
                                                     params, context);
}
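
// A minimal raw-accumulator call sketch, assuming hypothetical int8 inputs:
// with DstScalar == int32_t the accumulators are returned unscaled, so no
// requantization multiplier is set on GemmParams.
//
//   MatrixParams<int8_t> lhs_params;   // Order::kRowMajor, M x K
//   MatrixParams<int8_t> rhs_params;   // Order::kColMajor, K x N
//   MatrixParams<int32_t> dst_params;  // Order::kColMajor, M x N
//   GemmParams<int32_t, int32_t> gemm_params;
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, context);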

}  // namespace cpu_backend_gemm

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_