xref: /aosp_15_r20/external/gemmlowp/internal/output.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // output.h: processing the 32-bit accumulators output by the unpack
16*5f39d1b3SJooyung Han // stage, obtaining the final result matrix entries and storing them into
17*5f39d1b3SJooyung Han // the destination matrix.
18*5f39d1b3SJooyung Han 
19*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_OUTPUT_H_
20*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_OUTPUT_H_
21*5f39d1b3SJooyung Han 
22*5f39d1b3SJooyung Han #include <cmath>
23*5f39d1b3SJooyung Han #include <tuple>
24*5f39d1b3SJooyung Han #include <type_traits>
25*5f39d1b3SJooyung Han #include <typeinfo>
26*5f39d1b3SJooyung Han 
27*5f39d1b3SJooyung Han #include "../fixedpoint/fixedpoint.h"
28*5f39d1b3SJooyung Han #include "../public/output_stages.h"
29*5f39d1b3SJooyung Han #include "simd_wrappers.h"
30*5f39d1b3SJooyung Han 
31*5f39d1b3SJooyung Han namespace gemmlowp {
32*5f39d1b3SJooyung Han 
// Primary template for evaluating a single output stage on a register
// buffer. It deliberately provides no implementation: real work is done by
// the partial specializations. Instantiating this body with any
// InputBufferType other than void fails the static_assert below, which
// reports that the (stage, buffer type) combination is unimplemented.
template <typename OutputStage, typename InputBufferType>
struct OutputStageEvalBufferImpl {
  // The condition depends on InputBufferType, so it is only checked when
  // this primary template is actually instantiated.
  static_assert(
      std::is_same<InputBufferType, void>::value,
      "Unimplemented: missing implementation of this output pipeline stage "
      "for this data type. This would happen if some architecture-specific "
      "SIMD back-end (output_$arch.h) were incomplete.");
};
42*5f39d1b3SJooyung Han 
43*5f39d1b3SJooyung Han template <typename OutputStage, typename InputType>
44*5f39d1b3SJooyung Han struct OutputStageEvalImpl {
45*5f39d1b3SJooyung Han   static constexpr int kRows = InputType::kRows;
46*5f39d1b3SJooyung Han   static constexpr int kCols = InputType::kCols;
47*5f39d1b3SJooyung Han   using InputBufferType = typename InputType::BufferType;
48*5f39d1b3SJooyung Han   using BufferEvalImplType =
49*5f39d1b3SJooyung Han       OutputStageEvalBufferImpl<OutputStage, InputBufferType>;
50*5f39d1b3SJooyung Han   using OutputBufferType = typename BufferEvalImplType::OutputType;
51*5f39d1b3SJooyung Han   using OutputScalarType = typename OutputBufferType::ScalarType;
52*5f39d1b3SJooyung Han   using OutputType = RegisterBlock<OutputScalarType, kRows, kCols>;
53*5f39d1b3SJooyung Han 
OutputStageEvalImplOutputStageEvalImpl54*5f39d1b3SJooyung Han   OutputStageEvalImpl(const OutputStage& s) : buffer_eval_impl(s) {}
55*5f39d1b3SJooyung Han 
EvalOutputStageEvalImpl56*5f39d1b3SJooyung Han   OutputType Eval(InputType input, int, int) const {
57*5f39d1b3SJooyung Han     OutputType output;
58*5f39d1b3SJooyung Han     output.buf = buffer_eval_impl.Eval(input.buf);
59*5f39d1b3SJooyung Han     return output;
60*5f39d1b3SJooyung Han   }
61*5f39d1b3SJooyung Han 
62*5f39d1b3SJooyung Han   const BufferEvalImplType buffer_eval_impl;
63*5f39d1b3SJooyung Han };
64*5f39d1b3SJooyung Han 
65*5f39d1b3SJooyung Han template <int Size>
66*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageQuantizeDownInt32ToUint8Scale,
67*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
68*5f39d1b3SJooyung Han   using InputType = RegisterBuffer<std::int32_t, Size>;
69*5f39d1b3SJooyung Han   using OutputType = RegisterBuffer<std::int32_t, Size>;
70*5f39d1b3SJooyung Han 
71*5f39d1b3SJooyung Han   typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage;
72*5f39d1b3SJooyung Han 
73*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
74*5f39d1b3SJooyung Han 
75*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
76*5f39d1b3SJooyung Han     const int result_shift = output_stage.result_shift;
77*5f39d1b3SJooyung Han     const std::int32_t result_mult_int = output_stage.result_mult_int;
78*5f39d1b3SJooyung Han     using RegisterType = typename InputType::RegisterType;
79*5f39d1b3SJooyung Han     const RegisterType result_offset =
80*5f39d1b3SJooyung Han         Dup<RegisterType>(output_stage.result_offset);
81*5f39d1b3SJooyung Han     OutputType output;
82*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
83*5f39d1b3SJooyung Han       output.reg[i] = RoundingDivideByPOT(
84*5f39d1b3SJooyung Han           Mul(Add(input.reg[i], result_offset), result_mult_int), result_shift);
85*5f39d1b3SJooyung Han     }
86*5f39d1b3SJooyung Han     return output;
87*5f39d1b3SJooyung Han   }
88*5f39d1b3SJooyung Han 
89*5f39d1b3SJooyung Han   const OutputStage& output_stage;
90*5f39d1b3SJooyung Han };
91*5f39d1b3SJooyung Han 
92*5f39d1b3SJooyung Han template <int Rows, int Cols, VectorShape Shape>
93*5f39d1b3SJooyung Han struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>,
94*5f39d1b3SJooyung Han                            RegisterBlock<std::int32_t, Rows, Cols>> {
95*5f39d1b3SJooyung Han   typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
96*5f39d1b3SJooyung Han   typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
97*5f39d1b3SJooyung Han   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> OutputStage;
98*5f39d1b3SJooyung Han 
99*5f39d1b3SJooyung Han   OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
100*5f39d1b3SJooyung Han 
101*5f39d1b3SJooyung Han   OutputType Eval(InputType input, int row, int col) const {
102*5f39d1b3SJooyung Han     OutputType output;
103*5f39d1b3SJooyung Han     const int result_shift = output_stage.result_shift;
104*5f39d1b3SJooyung Han     const int pos = Shape == VectorShape::Col ? row : col;
105*5f39d1b3SJooyung Han     const auto result_mult_int =
106*5f39d1b3SJooyung Han         LoadForBroadcasting<InputType>(output_stage.result_mult_int, pos);
107*5f39d1b3SJooyung Han     const auto result_offset =
108*5f39d1b3SJooyung Han         LoadForBroadcasting<InputType>(output_stage.result_offset, pos);
109*5f39d1b3SJooyung Han     const auto dividend = BroadcastMul<InputType>(
110*5f39d1b3SJooyung Han         BroadcastAdd<InputType>(input, result_offset), result_mult_int);
111*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
112*5f39d1b3SJooyung Han       output.buf.reg[i] =
113*5f39d1b3SJooyung Han           RoundingDivideByPOT(dividend.buf.reg[i], result_shift);
114*5f39d1b3SJooyung Han     }
115*5f39d1b3SJooyung Han     return output;
116*5f39d1b3SJooyung Han   }
117*5f39d1b3SJooyung Han 
118*5f39d1b3SJooyung Han   const OutputStage& output_stage;
119*5f39d1b3SJooyung Han };
120*5f39d1b3SJooyung Han 
121*5f39d1b3SJooyung Han template <int Size>
122*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<
123*5f39d1b3SJooyung Han     OutputStageQuantizeDownInt32ByFixedPoint,
124*5f39d1b3SJooyung Han     RegisterBuffer<std::int32_t, Size>> {
125*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
126*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> OutputType;
127*5f39d1b3SJooyung Han 
128*5f39d1b3SJooyung Han   typedef OutputStageQuantizeDownInt32ByFixedPoint OutputStage;
129*5f39d1b3SJooyung Han 
130*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
131*5f39d1b3SJooyung Han 
132*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
133*5f39d1b3SJooyung Han     OutputType output;
134*5f39d1b3SJooyung Han     using RegisterType = typename InputType::RegisterType;
135*5f39d1b3SJooyung Han     const RegisterType result_offset_after_shift =
136*5f39d1b3SJooyung Han         Dup<RegisterType>(output_stage.result_offset_after_shift);
137*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
138*5f39d1b3SJooyung Han       const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul(
139*5f39d1b3SJooyung Han           input.reg[i], output_stage.result_fixedpoint_multiplier);
140*5f39d1b3SJooyung Han       output.reg[i] =
141*5f39d1b3SJooyung Han           Add(RoundingDivideByPOT(mulhigh_val, output_stage.result_shift),
142*5f39d1b3SJooyung Han               result_offset_after_shift);
143*5f39d1b3SJooyung Han     }
144*5f39d1b3SJooyung Han     return output;
145*5f39d1b3SJooyung Han   }
146*5f39d1b3SJooyung Han 
147*5f39d1b3SJooyung Han   const OutputStage& output_stage;
148*5f39d1b3SJooyung Han };
149*5f39d1b3SJooyung Han 
150*5f39d1b3SJooyung Han template <int Size>
151*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent,
152*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
153*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
154*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> OutputType;
155*5f39d1b3SJooyung Han 
156*5f39d1b3SJooyung Han   typedef OutputStageScaleInt32ByFixedPointAndExponent OutputStage;
157*5f39d1b3SJooyung Han 
158*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {
159*5f39d1b3SJooyung Han     left_shift = std::max(0, output_stage.result_exponent);
160*5f39d1b3SJooyung Han     right_shift = std::max(0, -output_stage.result_exponent);
161*5f39d1b3SJooyung Han   }
162*5f39d1b3SJooyung Han 
163*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
164*5f39d1b3SJooyung Han     OutputType output;
165*5f39d1b3SJooyung Han     using RegisterType = typename InputType::RegisterType;
166*5f39d1b3SJooyung Han     const RegisterType result_offset_after_shift =
167*5f39d1b3SJooyung Han         Dup<RegisterType>(output_stage.result_offset_after_shift);
168*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
169*5f39d1b3SJooyung Han       const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul(
170*5f39d1b3SJooyung Han           ShiftLeft(input.reg[i], left_shift),
171*5f39d1b3SJooyung Han           output_stage.result_fixedpoint_multiplier);
172*5f39d1b3SJooyung Han       output.reg[i] = Add(RoundingDivideByPOT(mulhigh_val, right_shift),
173*5f39d1b3SJooyung Han                           result_offset_after_shift);
174*5f39d1b3SJooyung Han     }
175*5f39d1b3SJooyung Han     return output;
176*5f39d1b3SJooyung Han   }
177*5f39d1b3SJooyung Han 
178*5f39d1b3SJooyung Han   const OutputStage& output_stage;
179*5f39d1b3SJooyung Han   int left_shift;
180*5f39d1b3SJooyung Han   int right_shift;
181*5f39d1b3SJooyung Han };
182*5f39d1b3SJooyung Han 
183*5f39d1b3SJooyung Han template <int Rows, int Cols, VectorShape Shape>
184*5f39d1b3SJooyung Han struct OutputStageEvalImpl<
185*5f39d1b3SJooyung Han     OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>,
186*5f39d1b3SJooyung Han     RegisterBlock<std::int32_t, Rows, Cols>> {
187*5f39d1b3SJooyung Han   typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
188*5f39d1b3SJooyung Han   typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
189*5f39d1b3SJooyung Han 
190*5f39d1b3SJooyung Han   typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage;
191*5f39d1b3SJooyung Han 
192*5f39d1b3SJooyung Han   OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
193*5f39d1b3SJooyung Han 
194*5f39d1b3SJooyung Han   OutputType Eval(InputType input, int row, int col) const {
195*5f39d1b3SJooyung Han     OutputType output;
196*5f39d1b3SJooyung Han     const int pos = Shape == VectorShape::Row ? col : row;
197*5f39d1b3SJooyung Han     using RegisterType = typename InputType::RegisterType;
198*5f39d1b3SJooyung Han     const RegisterType result_offset_after_shift =
199*5f39d1b3SJooyung Han         Dup<RegisterType>(output_stage.result_offset_after_shift);
200*5f39d1b3SJooyung Han     auto left_shift =
201*5f39d1b3SJooyung Han         LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
202*5f39d1b3SJooyung Han     auto right_shift =
203*5f39d1b3SJooyung Han         LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
204*5f39d1b3SJooyung Han     const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>(
205*5f39d1b3SJooyung Han         output_stage.result_fixedpoint_multiplier, pos);
206*5f39d1b3SJooyung Han     for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) {
207*5f39d1b3SJooyung Han       left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0);
208*5f39d1b3SJooyung Han       right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0);
209*5f39d1b3SJooyung Han     }
210*5f39d1b3SJooyung Han     const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul(
211*5f39d1b3SJooyung Han         BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier);
212*5f39d1b3SJooyung Han     const auto rdpot_val =
213*5f39d1b3SJooyung Han         BroadcastRoundingDivideByPOT(mulhigh_val, right_shift);
214*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
215*5f39d1b3SJooyung Han       output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift);
216*5f39d1b3SJooyung Han     }
217*5f39d1b3SJooyung Han     return output;
218*5f39d1b3SJooyung Han   }
219*5f39d1b3SJooyung Han 
220*5f39d1b3SJooyung Han   const OutputStage& output_stage;
221*5f39d1b3SJooyung Han };
222*5f39d1b3SJooyung Han 
223*5f39d1b3SJooyung Han // Implementation of OutputStageSaturatingCastToUint8 for scalar data.
224*5f39d1b3SJooyung Han template <int Size>
225*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
226*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
227*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
228*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::uint8_t, Size> OutputType;
229*5f39d1b3SJooyung Han   static_assert(InputType::kRegisterLanes == 1,
230*5f39d1b3SJooyung Han                 "This path is only for scalar values");
231*5f39d1b3SJooyung Han 
232*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToUint8 OutputStage;
233*5f39d1b3SJooyung Han 
234*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
235*5f39d1b3SJooyung Han 
236*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
237*5f39d1b3SJooyung Han     OutputType output;
238*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
239*5f39d1b3SJooyung Han       std::int32_t data = input.reg[i];
240*5f39d1b3SJooyung Han       output.reg[i] = data > 255 ? 255 : data < 0 ? 0 : data;
241*5f39d1b3SJooyung Han     }
242*5f39d1b3SJooyung Han     return output;
243*5f39d1b3SJooyung Han   }
244*5f39d1b3SJooyung Han };
245*5f39d1b3SJooyung Han 
246*5f39d1b3SJooyung Han // Implementation of OutputStageSaturatingCastToInt8 for scalar data.
247*5f39d1b3SJooyung Han template <int Size>
248*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
249*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
250*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
251*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int8_t, Size> OutputType;
252*5f39d1b3SJooyung Han   static_assert(InputType::kRegisterLanes == 1,
253*5f39d1b3SJooyung Han                 "This path is only for scalar values");
254*5f39d1b3SJooyung Han 
255*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToInt8 OutputStage;
256*5f39d1b3SJooyung Han 
257*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
258*5f39d1b3SJooyung Han 
259*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
260*5f39d1b3SJooyung Han     OutputType output;
261*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
262*5f39d1b3SJooyung Han       std::int32_t data = input.reg[i];
263*5f39d1b3SJooyung Han       output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data;
264*5f39d1b3SJooyung Han     }
265*5f39d1b3SJooyung Han     return output;
266*5f39d1b3SJooyung Han   }
267*5f39d1b3SJooyung Han };
268*5f39d1b3SJooyung Han 
269*5f39d1b3SJooyung Han // Implementation of OutputStageSaturatingCastToInt16 for scalar data.
270*5f39d1b3SJooyung Han template <int Size>
271*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
272*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
273*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
274*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int16_t, Size> OutputType;
275*5f39d1b3SJooyung Han   static_assert(InputType::kRegisterLanes == 1,
276*5f39d1b3SJooyung Han                 "This path is only for scalar values");
277*5f39d1b3SJooyung Han 
278*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToInt16 OutputStage;
279*5f39d1b3SJooyung Han 
280*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
281*5f39d1b3SJooyung Han 
282*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
283*5f39d1b3SJooyung Han     OutputType output;
284*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
285*5f39d1b3SJooyung Han       std::int32_t data = input.reg[i];
286*5f39d1b3SJooyung Han       output.reg[i] = data > 32767 ? 32767 : data < -32768 ? -32768 : data;
287*5f39d1b3SJooyung Han     }
288*5f39d1b3SJooyung Han     return output;
289*5f39d1b3SJooyung Han   }
290*5f39d1b3SJooyung Han };
291*5f39d1b3SJooyung Han 
292*5f39d1b3SJooyung Han // Implementation of OutputStageTruncatingCastToUint8 for scalar data
293*5f39d1b3SJooyung Han template <int Size>
294*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
295*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
296*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
297*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::uint8_t, Size> OutputType;
298*5f39d1b3SJooyung Han   static_assert(InputType::kRegisterLanes == 1,
299*5f39d1b3SJooyung Han                 "This path is only for scalar values");
300*5f39d1b3SJooyung Han 
301*5f39d1b3SJooyung Han   typedef OutputStageTruncatingCastToUint8 OutputStage;
302*5f39d1b3SJooyung Han 
303*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
304*5f39d1b3SJooyung Han 
305*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
306*5f39d1b3SJooyung Han     OutputType output;
307*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
308*5f39d1b3SJooyung Han       output.reg[i] = input.reg[i];
309*5f39d1b3SJooyung Han     }
310*5f39d1b3SJooyung Han     return output;
311*5f39d1b3SJooyung Han   }
312*5f39d1b3SJooyung Han };
313*5f39d1b3SJooyung Han 
314*5f39d1b3SJooyung Han template <int Rows, int Cols, typename VectorType>
315*5f39d1b3SJooyung Han struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
316*5f39d1b3SJooyung Han                            RegisterBlock<std::int32_t, Rows, Cols>> {
317*5f39d1b3SJooyung Han   typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
318*5f39d1b3SJooyung Han   typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
319*5f39d1b3SJooyung Han   typedef OutputStageBiasAddition<VectorType> OutputStage;
320*5f39d1b3SJooyung Han 
321*5f39d1b3SJooyung Han   OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
322*5f39d1b3SJooyung Han 
323*5f39d1b3SJooyung Han   OutputType Eval(InputType input, int row, int col) const {
324*5f39d1b3SJooyung Han     const int pos = VectorType::kShape == VectorShape::Row ? col : row;
325*5f39d1b3SJooyung Han     return BroadcastAdd<InputType>(
326*5f39d1b3SJooyung Han         input, LoadForBroadcasting<InputType>(output_stage.bias_vector, pos));
327*5f39d1b3SJooyung Han   }
328*5f39d1b3SJooyung Han 
329*5f39d1b3SJooyung Han   const OutputStage& output_stage;
330*5f39d1b3SJooyung Han };
331*5f39d1b3SJooyung Han 
332*5f39d1b3SJooyung Han template <int Size>
333*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageClamp,
334*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
335*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
336*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> OutputType;
337*5f39d1b3SJooyung Han 
338*5f39d1b3SJooyung Han   typedef OutputStageClamp OutputStage;
339*5f39d1b3SJooyung Han 
340*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
341*5f39d1b3SJooyung Han 
342*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
343*5f39d1b3SJooyung Han     using RegisterType = typename InputType::RegisterType;
344*5f39d1b3SJooyung Han     const RegisterType min = Dup<RegisterType>(output_stage.min);
345*5f39d1b3SJooyung Han     const RegisterType max = Dup<RegisterType>(output_stage.max);
346*5f39d1b3SJooyung Han     OutputType output;
347*5f39d1b3SJooyung Han     for (int i = 0; i < InputType::kRegisterCount; i++) {
348*5f39d1b3SJooyung Han       output.reg[i] = Min(Max(input.reg[i], min), max);
349*5f39d1b3SJooyung Han     }
350*5f39d1b3SJooyung Han     return output;
351*5f39d1b3SJooyung Han   }
352*5f39d1b3SJooyung Han 
353*5f39d1b3SJooyung Han   const OutputStage& output_stage;
354*5f39d1b3SJooyung Han };
355*5f39d1b3SJooyung Han 
356*5f39d1b3SJooyung Han template <int Size>
357*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageTanh,
358*5f39d1b3SJooyung Han                                  RegisterBuffer<std::int32_t, Size>> {
359*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> InputType;
360*5f39d1b3SJooyung Han   typedef RegisterBuffer<std::int32_t, Size> OutputType;
361*5f39d1b3SJooyung Han   using RegisterType = typename InputType::RegisterType;
362*5f39d1b3SJooyung Han   typedef RegisterType DataType;
363*5f39d1b3SJooyung Han   typedef OutputStageTanh OutputStage;
364*5f39d1b3SJooyung Han 
365*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {
366*5f39d1b3SJooyung Han     const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32;
367*5f39d1b3SJooyung Han     const std::int32_t real_amplitude_as_int32 =
368*5f39d1b3SJooyung Han         output_stage.real_amplitude_as_int32;
369*5f39d1b3SJooyung Han 
370*5f39d1b3SJooyung Han     input_cutoff_min = real_zero_as_int32 - 8 * real_amplitude_as_int32;
371*5f39d1b3SJooyung Han     input_cutoff_max = real_zero_as_int32 + 8 * real_amplitude_as_int32;
372*5f39d1b3SJooyung Han     output_min = real_zero_as_int32 - real_amplitude_as_int32;
373*5f39d1b3SJooyung Han     output_max = real_zero_as_int32 + real_amplitude_as_int32;
374*5f39d1b3SJooyung Han 
375*5f39d1b3SJooyung Han     double inverse_amplitude_normalized_double = 1.0 / real_amplitude_as_int32;
376*5f39d1b3SJooyung Han     inverse_amplitude_neg_exponent = 0;
377*5f39d1b3SJooyung Han     while (inverse_amplitude_normalized_double < 0.5) {
378*5f39d1b3SJooyung Han       inverse_amplitude_normalized_double *= 2;
379*5f39d1b3SJooyung Han       inverse_amplitude_neg_exponent++;
380*5f39d1b3SJooyung Han     }
381*5f39d1b3SJooyung Han     inverse_amplitude_normalized = FixedPoint<DataType, 0>::FromDouble(
382*5f39d1b3SJooyung Han         inverse_amplitude_normalized_double);
383*5f39d1b3SJooyung Han 
384*5f39d1b3SJooyung Han     double amplitude_normalized_double = real_amplitude_as_int32;
385*5f39d1b3SJooyung Han     amplitude_exponent = 0;
386*5f39d1b3SJooyung Han     while (amplitude_normalized_double >= 1.0) {
387*5f39d1b3SJooyung Han       amplitude_normalized_double *= 0.5;
388*5f39d1b3SJooyung Han       amplitude_exponent++;
389*5f39d1b3SJooyung Han     }
390*5f39d1b3SJooyung Han     amplitude_normalized =
391*5f39d1b3SJooyung Han         FixedPoint<DataType, 0>::FromDouble(amplitude_normalized_double);
392*5f39d1b3SJooyung Han   }
393*5f39d1b3SJooyung Han 
394*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
395*5f39d1b3SJooyung Han     const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32;
396*5f39d1b3SJooyung Han 
397*5f39d1b3SJooyung Han     typedef FixedPoint<DataType, 3> F3;
398*5f39d1b3SJooyung Han     typedef FixedPoint<DataType, 0> F0;
399*5f39d1b3SJooyung Han 
400*5f39d1b3SJooyung Han     OutputType output;
401*5f39d1b3SJooyung Han 
402*5f39d1b3SJooyung Han     for (int i = 0; i < OutputType::kRegisterCount; i++) {
403*5f39d1b3SJooyung Han       // fixed-point affine transformation
404*5f39d1b3SJooyung Han       DataType input_centered =
405*5f39d1b3SJooyung Han           Sub(input.reg[i], Dup<DataType>(real_zero_as_int32));
406*5f39d1b3SJooyung Han       F3 fixedpoint_input =
407*5f39d1b3SJooyung Han           F3::FromRaw(input_centered) * inverse_amplitude_normalized;
408*5f39d1b3SJooyung Han       // left shift
409*5f39d1b3SJooyung Han       fixedpoint_input.raw() = ShiftLeft(fixedpoint_input.raw(),
410*5f39d1b3SJooyung Han                                          28 - inverse_amplitude_neg_exponent);
411*5f39d1b3SJooyung Han       // fixed-point tanh and multiplication
412*5f39d1b3SJooyung Han       F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized;
413*5f39d1b3SJooyung Han       // right shift
414*5f39d1b3SJooyung Han       DataType int32_output =
415*5f39d1b3SJooyung Han           Add(Dup<DataType>(real_zero_as_int32),
416*5f39d1b3SJooyung Han               ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent));
417*5f39d1b3SJooyung Han 
418*5f39d1b3SJooyung Han       DataType mask_if_below_cutoff_min =
419*5f39d1b3SJooyung Han           MaskIfLessThanOrEqual(input.reg[i], Dup<DataType>(input_cutoff_min));
420*5f39d1b3SJooyung Han       DataType mask_if_above_cutoff_max = MaskIfGreaterThanOrEqual(
421*5f39d1b3SJooyung Han           input.reg[i], Dup<DataType>(input_cutoff_max));
422*5f39d1b3SJooyung Han 
423*5f39d1b3SJooyung Han       output.reg[i] = SelectUsingMask(
424*5f39d1b3SJooyung Han           mask_if_below_cutoff_min, Dup<DataType>(output_min),
425*5f39d1b3SJooyung Han           SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max),
426*5f39d1b3SJooyung Han                           int32_output));
427*5f39d1b3SJooyung Han     }
428*5f39d1b3SJooyung Han     return output;
429*5f39d1b3SJooyung Han   }
430*5f39d1b3SJooyung Han 
431*5f39d1b3SJooyung Han   const OutputStage& output_stage;
432*5f39d1b3SJooyung Han   std::int32_t input_cutoff_min, input_cutoff_max;
433*5f39d1b3SJooyung Han   std::int32_t output_min, output_max;
434*5f39d1b3SJooyung Han   FixedPoint<DataType, 0> inverse_amplitude_normalized;
435*5f39d1b3SJooyung Han   int inverse_amplitude_neg_exponent;
436*5f39d1b3SJooyung Han   FixedPoint<DataType, 0> amplitude_normalized;
437*5f39d1b3SJooyung Han   int amplitude_exponent;
438*5f39d1b3SJooyung Han };
439*5f39d1b3SJooyung Han 
440*5f39d1b3SJooyung Han // OutputPipelineOutputType is a helper to determine the output data type of a
441*5f39d1b3SJooyung Han // pipeline, for a
442*5f39d1b3SJooyung Han // given input data type. It is a recursive template; see the explanation on
443*5f39d1b3SJooyung Han // OutputPipelineEvalImpl below.
444*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType,
445*5f39d1b3SJooyung Han           bool StopRecursion =
446*5f39d1b3SJooyung Han               FirstStage == std::tuple_size<OutputPipelineType>::value>
447*5f39d1b3SJooyung Han struct OutputPipelineOutputType {
448*5f39d1b3SJooyung Han   typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type
449*5f39d1b3SJooyung Han       FirstStageType;
450*5f39d1b3SJooyung Han   typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType
451*5f39d1b3SJooyung Han       FirstStageOutputType;
452*5f39d1b3SJooyung Han   typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage + 1,
453*5f39d1b3SJooyung Han                                             FirstStageOutputType>::Type Type;
454*5f39d1b3SJooyung Han };
455*5f39d1b3SJooyung Han 
456*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType>
457*5f39d1b3SJooyung Han struct OutputPipelineOutputType<OutputPipelineType, FirstStage, InputType,
458*5f39d1b3SJooyung Han                                 true> {
459*5f39d1b3SJooyung Han   typedef InputType Type;
460*5f39d1b3SJooyung Han };
461*5f39d1b3SJooyung Han 
462*5f39d1b3SJooyung Han // OutputPipelineEvalImpl is a helper to implement the evaluation of
463*5f39d1b3SJooyung Han // the whole pipeline. It is a recursive template to implement compile-time
464*5f39d1b3SJooyung Han // unrolling of the loop over all pipeline stages. The 'FirstStage' parameter
465*5f39d1b3SJooyung Han // is how we implement recursion: each specialization implements only
466*5f39d1b3SJooyung Han // evaluation starting at 'FirstStage'. The StopRecursion parameter is just a
467*5f39d1b3SJooyung Han // helper to implement the termination of the recursion as a partial
468*5f39d1b3SJooyung Han // specialization below.
469*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType,
470*5f39d1b3SJooyung Han           bool StopRecursion =
471*5f39d1b3SJooyung Han               FirstStage == std::tuple_size<OutputPipelineType>::value>
472*5f39d1b3SJooyung Han struct OutputPipelineEvalImpl {
473*5f39d1b3SJooyung Han   typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type
474*5f39d1b3SJooyung Han       FirstStageType;
475*5f39d1b3SJooyung Han   typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType
476*5f39d1b3SJooyung Han       FirstStageOutputType;
477*5f39d1b3SJooyung Han   typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage,
478*5f39d1b3SJooyung Han                                             InputType>::Type OutputType;
479*5f39d1b3SJooyung Han 
480*5f39d1b3SJooyung Han   OutputPipelineEvalImpl(const OutputPipelineType& output_pipeline)
481*5f39d1b3SJooyung Han       : head_impl(std::get<FirstStage>(output_pipeline)),
482*5f39d1b3SJooyung Han         tail_impl(output_pipeline) {}
483*5f39d1b3SJooyung Han 
484*5f39d1b3SJooyung Han   OutputType Eval(InputType input, int row, int col) const {
485*5f39d1b3SJooyung Han     // Evaluate the first stage.
486*5f39d1b3SJooyung Han     FirstStageOutputType first_stage_output = head_impl.Eval(input, row, col);
487*5f39d1b3SJooyung Han     // Recurse into the remaining stages.
488*5f39d1b3SJooyung Han     return tail_impl.Eval(first_stage_output, row, col);
489*5f39d1b3SJooyung Han   }
490*5f39d1b3SJooyung Han 
491*5f39d1b3SJooyung Han   const OutputStageEvalImpl<FirstStageType, InputType> head_impl;
492*5f39d1b3SJooyung Han   const OutputPipelineEvalImpl<OutputPipelineType, FirstStage + 1,
493*5f39d1b3SJooyung Han                                FirstStageOutputType>
494*5f39d1b3SJooyung Han       tail_impl;
495*5f39d1b3SJooyung Han };
496*5f39d1b3SJooyung Han 
497*5f39d1b3SJooyung Han // Specialization on 'StopRecursion' for terminating the recursion.
498*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType>
499*5f39d1b3SJooyung Han struct OutputPipelineEvalImpl<OutputPipelineType, FirstStage, InputType, true> {
500*5f39d1b3SJooyung Han   OutputPipelineEvalImpl(const OutputPipelineType&) {}
501*5f39d1b3SJooyung Han 
502*5f39d1b3SJooyung Han   InputType Eval(InputType input, int, int) const {
503*5f39d1b3SJooyung Han     // Terminating the recursion.
504*5f39d1b3SJooyung Han     return input;
505*5f39d1b3SJooyung Han   }
506*5f39d1b3SJooyung Han };
507*5f39d1b3SJooyung Han 
// Primary StoreFinalOutputImpl template. It deliberately provides no Run():
// the static_assert below only holds when RegisterBlockType is void, so
// instantiating this generic version with any other register block type is a
// compile-time error. Specializations supply the actual store logic.
template <typename RegisterBlockType, typename DstType>
struct StoreFinalOutputImpl {
  static_assert(std::is_same<RegisterBlockType, void>::value,
                "This generic impl should never be hit");
};
513*5f39d1b3SJooyung Han 
514*5f39d1b3SJooyung Han template <typename ScalarType, int Rows, int Cols, typename DstType>
515*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegisterBlock<ScalarType, Rows, Cols>, DstType> {
516*5f39d1b3SJooyung Han   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
517*5f39d1b3SJooyung Han   static void Run(const RegisterBlockType& src, DstType* dst, int row,
518*5f39d1b3SJooyung Han                   int col) {
519*5f39d1b3SJooyung Han     for (int r = 0; r < Rows; r++) {
520*5f39d1b3SJooyung Han       for (int c = 0; c < Cols; c++) {
521*5f39d1b3SJooyung Han         *dst->data(row + r, col + c) = src.buf.reg[r + c * Rows];
522*5f39d1b3SJooyung Han       }
523*5f39d1b3SJooyung Han     }
524*5f39d1b3SJooyung Han   }
525*5f39d1b3SJooyung Han };
526*5f39d1b3SJooyung Han 
527*5f39d1b3SJooyung Han // StoreFinalOutput takes the final value at the end of the output pipeline and
528*5f39d1b3SJooyung Han // stores it into the destination matrix. It can be specialized for different
529*5f39d1b3SJooyung Han // data types; the generic implementation here is typically used only for plain
530*5f39d1b3SJooyung Han // old scalar (not SIMD) types.
531*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename DstType>
532*5f39d1b3SJooyung Han void StoreFinalOutput(RegisterBlockType src, DstType* dst, int row, int col) {
533*5f39d1b3SJooyung Han   StoreFinalOutputImpl<RegisterBlockType, DstType>::Run(src, dst, row, col);
534*5f39d1b3SJooyung Han }
535*5f39d1b3SJooyung Han 
536*5f39d1b3SJooyung Han template <typename OutputPipelineType, typename InputType>
537*5f39d1b3SJooyung Han struct OutputPipelineExecutor {
538*5f39d1b3SJooyung Han   OutputPipelineExecutor(const OutputPipelineType& output_pipeline)
539*5f39d1b3SJooyung Han       : output_pipeline_eval_impl_(output_pipeline) {}
540*5f39d1b3SJooyung Han 
541*5f39d1b3SJooyung Han   // Execute is the entry point into the output pipeline evaluation
542*5f39d1b3SJooyung Han   // code. It should be the only thing that unpack code calls. It takes the
543*5f39d1b3SJooyung Han   // result
544*5f39d1b3SJooyung Han   // of the unpack stage and stores it into the destination matrix.
545*5f39d1b3SJooyung Han   template <typename DstType>
546*5f39d1b3SJooyung Han   void Execute(InputType input, DstType* dst, int src_global_row,
547*5f39d1b3SJooyung Han                int src_global_col, int dst_row, int dst_col) const {
548*5f39d1b3SJooyung Han     // Statically assert that the output pipeline matches the given destination
549*5f39d1b3SJooyung Han     // matrix's scalar type.
550*5f39d1b3SJooyung Han     typedef typename OutputPipelineOutputType<
551*5f39d1b3SJooyung Han         OutputPipelineType, 0, InputType>::Type::BufferType::ScalarType
552*5f39d1b3SJooyung Han 
553*5f39d1b3SJooyung Han         ScalarOutputType;
554*5f39d1b3SJooyung Han     typedef typename DstType::Scalar ScalarDstType;
555*5f39d1b3SJooyung Han     static_assert(std::is_same<ScalarOutputType, ScalarDstType>::value,
556*5f39d1b3SJooyung Han                   "mismatched destination scalar type and output pipeline");
557*5f39d1b3SJooyung Han 
558*5f39d1b3SJooyung Han     // Evaluate the output pipeline.
559*5f39d1b3SJooyung Han     auto output =
560*5f39d1b3SJooyung Han         output_pipeline_eval_impl_.Eval(input, src_global_row, src_global_col);
561*5f39d1b3SJooyung Han     // Store the result into the destination matrix.
562*5f39d1b3SJooyung Han     StoreFinalOutput(output, dst, dst_row, dst_col);
563*5f39d1b3SJooyung Han   }
564*5f39d1b3SJooyung Han 
565*5f39d1b3SJooyung Han   const OutputPipelineEvalImpl<OutputPipelineType, 0, InputType>
566*5f39d1b3SJooyung Han       output_pipeline_eval_impl_;
567*5f39d1b3SJooyung Han };
568*5f39d1b3SJooyung Han 
569*5f39d1b3SJooyung Han }  // namespace gemmlowp
570*5f39d1b3SJooyung Han 
571*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON
572*5f39d1b3SJooyung Han #include "output_neon.h"
573*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_SSE4)
574*5f39d1b3SJooyung Han #include "output_sse.h"
575*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_MSA)
576*5f39d1b3SJooyung Han #include "output_msa.h"
577*5f39d1b3SJooyung Han #endif
578*5f39d1b3SJooyung Han 
579*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_OUTPUT_H_
580