xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
18 
19 #include <tuple>
20 
21 #include "tensorflow/lite/kernels/internal/compatibility.h"
22 #ifndef TFLITE_WITH_RUY
23 
24 #include <cstdint>
25 #include <type_traits>
26 
27 #include "public/gemmlowp.h"
28 #include "tensorflow/lite/kernels/cpu_backend_context.h"
29 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
30 #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
31 
32 namespace tflite {
33 namespace cpu_backend_gemm {
34 namespace detail {
35 
36 template <typename DstScalar>
37 struct GemmlowpSaturatingCastStage {};
38 
39 template <>
40 struct GemmlowpSaturatingCastStage<std::uint8_t> {
41   using Type = gemmlowp::OutputStageSaturatingCastToUint8;
42 };
43 
44 template <>
45 struct GemmlowpSaturatingCastStage<std::int8_t> {
46   using Type = gemmlowp::OutputStageSaturatingCastToInt8;
47 };
48 
49 template <>
50 struct GemmlowpSaturatingCastStage<std::int16_t> {
51   using Type = gemmlowp::OutputStageSaturatingCastToInt16;
52 };
53 
54 template <typename DstScalar>
55 struct GemmlowpBitDepthParams {};
56 
57 template <>
58 struct GemmlowpBitDepthParams<std::uint8_t> {
59   using Type = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
60 };
61 
62 template <>
63 struct GemmlowpBitDepthParams<std::int8_t> {
64   using Type = gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams;
65 };
66 
67 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
68           typename DstScalar, QuantizationFlavor quantization_flavor>
69 struct GemmImplUsingGemmlowp {};
70 
71 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
72           typename DstScalar>
73 struct GemmImplUsingGemmlowp<
74     LhsScalar, RhsScalar, AccumScalar, DstScalar,
75     QuantizationFlavor::kIntegerWithUniformMultiplier> {
76   static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
77   static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
78   using SrcScalar = LhsScalar;
79 
80   static void Run(
81       const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
82       const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
83       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
84       const GemmParams<std::int32_t, DstScalar,
85                        QuantizationFlavor::kIntegerWithUniformMultiplier>&
86           params,
87       CpuBackendContext* context) {
88     gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
89         gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
90     gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
91         gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
92     gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
93         dst_data, dst_params.rows, dst_params.cols);
94 
95     using ColVectorMap =
96         gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
97     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
98     scale_stage.result_offset_after_shift = dst_params.zero_point;
99     scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
100     scale_stage.result_exponent = params.multiplier_exponent;
101     using SaturatingCastStageType =
102         typename GemmlowpSaturatingCastStage<DstScalar>::Type;
103     gemmlowp::OutputStageClamp clamp_stage;
104     clamp_stage.min = params.clamp_min;
105     clamp_stage.max = params.clamp_max;
106     SaturatingCastStageType saturating_cast_stage;
107     using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
108     if (params.bias) {
109       ColVectorMap bias_vector(params.bias, lhs_params.rows);
110       gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
111       bias_addition_stage.bias_vector = bias_vector;
112       auto output_pipeline = std::make_tuple(
113           bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
114       gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
115           context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
116           &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
117           output_pipeline);
118     } else {
119       auto output_pipeline =
120           std::make_tuple(scale_stage, clamp_stage, saturating_cast_stage);
121       gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
122           context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
123           &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
124           output_pipeline);
125     }
126   }
127 };
128 
129 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
130           typename DstScalar>
131 struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar,
132                              QuantizationFlavor::kIntegerWithPerRowMultiplier> {
133   static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
134   static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
135   using SrcScalar = LhsScalar;
136 
137   static void Run(
138       const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
139       const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
140       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
141       const GemmParams<std::int32_t, DstScalar,
142                        QuantizationFlavor::kIntegerWithPerRowMultiplier>&
143           params,
144       CpuBackendContext* context) {
145     // gemmlowp support for this per-channel path is limited to NEON.
146     // We fall back to ruy outside of NEON.
147 #ifdef GEMMLOWP_NEON
148     gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
149         gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
150     gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
151         gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
152     gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
153         dst_data, dst_params.rows, dst_params.cols);
154 
155     using ColVectorMap =
156         gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
157     ColVectorMap bias_vector(params.bias, lhs_params.rows);
158     gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
159     bias_addition_stage.bias_vector = bias_vector;
160     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
161         gemmlowp::VectorShape::Col>
162         scale_stage;
163     scale_stage.result_offset_after_shift = dst_params.zero_point;
164     scale_stage.result_fixedpoint_multiplier =
165         ColVectorMap(params.multiplier_fixedpoint_perchannel, dst_params.rows);
166     scale_stage.result_exponent =
167         ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
168     using SaturatingCastStageType =
169         typename GemmlowpSaturatingCastStage<DstScalar>::Type;
170     gemmlowp::OutputStageClamp clamp_stage;
171     clamp_stage.min = params.clamp_min;
172     clamp_stage.max = params.clamp_max;
173     SaturatingCastStageType saturating_cast_stage;
174     auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
175                                            clamp_stage, saturating_cast_stage);
176     using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
177     gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
178         context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
179         -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
180 #else
181     GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
182                      QuantizationFlavor::kIntegerWithPerRowMultiplier>::
183         Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
184             params, context);
185 #endif
186   }
187 };
188 
189 }  // namespace detail
190 }  // namespace cpu_backend_gemm
191 }  // namespace tflite
192 
193 #endif  // not TFLITE_WITH_RUY
194 
195 #endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_
196