xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
18 
19 #include "ruy/matrix.h"  // from @ruy
20 #include "ruy/mul_params.h"  // from @ruy
21 #include "ruy/ruy.h"  // from @ruy
22 #include "tensorflow/lite/kernels/cpu_backend_context.h"
23 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
24 #include "tensorflow/lite/kernels/internal/compatibility.h"
25 
26 namespace tflite {
27 namespace cpu_backend_gemm {
28 namespace detail {
29 
ToRuyCachePolicy(CachePolicy cache_policy)30 inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
31   switch (cache_policy) {
32     case CachePolicy::kNeverCache:
33       return ruy::CachePolicy::kNeverCache;
34     case CachePolicy::kCacheIfLargeSpeedup:
35       return ruy::CachePolicy::kCacheIfLargeSpeedup;
36     case CachePolicy::kAlwaysCache:
37       return ruy::CachePolicy::kAlwaysCache;
38     default:
39       TFLITE_DCHECK(false);
40       return ruy::CachePolicy::kNeverCache;
41   }
42 }
43 
44 template <typename Scalar, typename DataPointer>
45 void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
46                    ruy::Matrix<Scalar>* dst, bool use_caching = false) {
47   ruy::Order ruy_order = params.order == Order::kColMajor
48                              ? ruy::Order::kColMajor
49                              : ruy::Order::kRowMajor;
50   ruy::MakeSimpleLayout(params.rows, params.cols, ruy_order,
51                         dst->mutable_layout());
52   // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
53   // It does care whether we assign to it a Scalar* or a const Scalar*.
54   dst->set_data(data_ptr);
55   dst->set_zero_point(params.zero_point);
56   if (use_caching) {
57     dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
58   }
59 }
60 
61 // Floating-point case.
62 template <typename AccumScalar, typename DstScalar,
63           QuantizationFlavor quantization_flavor>
64 struct MakeRuyMulParamsImpl final {
Runfinal65   static void Run(
66       const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
67       ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
68     static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint,
69                   "");
70     ruy_mul_params->set_bias(params.bias);
71     ruy_mul_params->set_clamp_min(params.clamp_min);
72     ruy_mul_params->set_clamp_max(params.clamp_max);
73   }
74 };
75 
76 // Integer-quantized case with destination type narrower than int32
77 template <typename DstScalar, QuantizationFlavor quantization_flavor>
78 struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>
79     final {
80   static void Run(
81       const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
82       ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) {
83     static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
84     if (quantization_flavor ==
85         QuantizationFlavor::kIntegerWithUniformMultiplier) {
86       ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
87       ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
88     }
89     if (quantization_flavor ==
90         QuantizationFlavor::kIntegerWithPerRowMultiplier) {
91       ruy_mul_params->set_multiplier_fixedpoint_perchannel(
92           params.multiplier_fixedpoint_perchannel);
93       ruy_mul_params->set_multiplier_exponent_perchannel(
94           params.multiplier_exponent_perchannel);
95     }
96     ruy_mul_params->set_bias(params.bias);
97     ruy_mul_params->set_clamp_min(params.clamp_min);
98     ruy_mul_params->set_clamp_max(params.clamp_max);
99   }
100 };
101 
102 // Raw-integer case with destination type int32.
103 template <QuantizationFlavor quantization_flavor>
104 struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>
105     final {
106   static void Run(
107       const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params,
108       ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) {
109     ruy_mul_params->set_bias(params.bias);
110   }
111 };
112 
113 template <typename AccumScalar, typename DstScalar,
114           QuantizationFlavor quantization_flavor>
115 void MakeRuyMulParams(
116     const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
117     ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
118   MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run(
119       params, ruy_mul_params);
120 }
121 
122 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
123           typename DstScalar, QuantizationFlavor quantization_flavor>
124 struct GemmImplUsingRuy {
125   static void Run(
126       const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
127       const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
128       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
129       const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
130       CpuBackendContext* context) {
131     ruy::Matrix<LhsScalar> ruy_lhs;
132     ruy::Matrix<RhsScalar> ruy_rhs;
133     ruy::Matrix<DstScalar> ruy_dst;
134     MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching());
135     MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching());
136     MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
137 
138     ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;
139     MakeRuyMulParams(params, &ruy_mul_params);
140 
141     ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, context->ruy_context(),
142              &ruy_dst);
143   }
144 };
145 
146 }  // namespace detail
147 }  // namespace cpu_backend_gemm
148 }  // namespace tflite
149 
150 #endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
151