1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
18
19 #include "ruy/matrix.h" // from @ruy
20 #include "ruy/mul_params.h" // from @ruy
21 #include "ruy/ruy.h" // from @ruy
22 #include "tensorflow/lite/kernels/cpu_backend_context.h"
23 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
24 #include "tensorflow/lite/kernels/internal/compatibility.h"
25
26 namespace tflite {
27 namespace cpu_backend_gemm {
28 namespace detail {
29
ToRuyCachePolicy(CachePolicy cache_policy)30 inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
31 switch (cache_policy) {
32 case CachePolicy::kNeverCache:
33 return ruy::CachePolicy::kNeverCache;
34 case CachePolicy::kCacheIfLargeSpeedup:
35 return ruy::CachePolicy::kCacheIfLargeSpeedup;
36 case CachePolicy::kAlwaysCache:
37 return ruy::CachePolicy::kAlwaysCache;
38 default:
39 TFLITE_DCHECK(false);
40 return ruy::CachePolicy::kNeverCache;
41 }
42 }
43
44 template <typename Scalar, typename DataPointer>
45 void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
46 ruy::Matrix<Scalar>* dst, bool use_caching = false) {
47 ruy::Order ruy_order = params.order == Order::kColMajor
48 ? ruy::Order::kColMajor
49 : ruy::Order::kRowMajor;
50 ruy::MakeSimpleLayout(params.rows, params.cols, ruy_order,
51 dst->mutable_layout());
52 // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
53 // It does care whether we assign to it a Scalar* or a const Scalar*.
54 dst->set_data(data_ptr);
55 dst->set_zero_point(params.zero_point);
56 if (use_caching) {
57 dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
58 }
59 }
60
61 // Floating-point case.
62 template <typename AccumScalar, typename DstScalar,
63 QuantizationFlavor quantization_flavor>
64 struct MakeRuyMulParamsImpl final {
Runfinal65 static void Run(
66 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
67 ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
68 static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint,
69 "");
70 ruy_mul_params->set_bias(params.bias);
71 ruy_mul_params->set_clamp_min(params.clamp_min);
72 ruy_mul_params->set_clamp_max(params.clamp_max);
73 }
74 };
75
76 // Integer-quantized case with destination type narrower than int32
77 template <typename DstScalar, QuantizationFlavor quantization_flavor>
78 struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>
79 final {
80 static void Run(
81 const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
82 ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) {
83 static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
84 if (quantization_flavor ==
85 QuantizationFlavor::kIntegerWithUniformMultiplier) {
86 ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
87 ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
88 }
89 if (quantization_flavor ==
90 QuantizationFlavor::kIntegerWithPerRowMultiplier) {
91 ruy_mul_params->set_multiplier_fixedpoint_perchannel(
92 params.multiplier_fixedpoint_perchannel);
93 ruy_mul_params->set_multiplier_exponent_perchannel(
94 params.multiplier_exponent_perchannel);
95 }
96 ruy_mul_params->set_bias(params.bias);
97 ruy_mul_params->set_clamp_min(params.clamp_min);
98 ruy_mul_params->set_clamp_max(params.clamp_max);
99 }
100 };
101
102 // Raw-integer case with destination type int32.
103 template <QuantizationFlavor quantization_flavor>
104 struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>
105 final {
106 static void Run(
107 const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params,
108 ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) {
109 ruy_mul_params->set_bias(params.bias);
110 }
111 };
112
113 template <typename AccumScalar, typename DstScalar,
114 QuantizationFlavor quantization_flavor>
115 void MakeRuyMulParams(
116 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
117 ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
118 MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run(
119 params, ruy_mul_params);
120 }
121
122 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
123 typename DstScalar, QuantizationFlavor quantization_flavor>
124 struct GemmImplUsingRuy {
125 static void Run(
126 const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
127 const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
128 const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
129 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
130 CpuBackendContext* context) {
131 ruy::Matrix<LhsScalar> ruy_lhs;
132 ruy::Matrix<RhsScalar> ruy_rhs;
133 ruy::Matrix<DstScalar> ruy_dst;
134 MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching());
135 MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching());
136 MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
137
138 ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;
139 MakeRuyMulParams(params, &ruy_mul_params);
140
141 ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, context->ruy_context(),
142 &ruy_dst);
143 }
144 };
145
146 } // namespace detail
147 } // namespace cpu_backend_gemm
148 } // namespace tflite
149
150 #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
151