1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2*5f39d1b3SJooyung Han // 3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License"); 4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License. 5*5f39d1b3SJooyung Han // You may obtain a copy of the License at 6*5f39d1b3SJooyung Han // 7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0 8*5f39d1b3SJooyung Han // 9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software 10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS, 11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and 13*5f39d1b3SJooyung Han // limitations under the License. 14*5f39d1b3SJooyung Han 15*5f39d1b3SJooyung Han // output.h: processing the 32-bit accumulators output by the unpack 16*5f39d1b3SJooyung Han // stage, obtaining the final result matrix entries and storing them into 17*5f39d1b3SJooyung Han // the destination matrix. 18*5f39d1b3SJooyung Han 19*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_OUTPUT_H_ 20*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_OUTPUT_H_ 21*5f39d1b3SJooyung Han 22*5f39d1b3SJooyung Han #include <cmath> 23*5f39d1b3SJooyung Han #include <tuple> 24*5f39d1b3SJooyung Han #include <type_traits> 25*5f39d1b3SJooyung Han #include <typeinfo> 26*5f39d1b3SJooyung Han 27*5f39d1b3SJooyung Han #include "../fixedpoint/fixedpoint.h" 28*5f39d1b3SJooyung Han #include "../public/output_stages.h" 29*5f39d1b3SJooyung Han #include "simd_wrappers.h" 30*5f39d1b3SJooyung Han 31*5f39d1b3SJooyung Han namespace gemmlowp { 32*5f39d1b3SJooyung Han 33*5f39d1b3SJooyung Han template <typename OutputStage, typename InputBufferType> 34*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl { 35*5f39d1b3SJooyung Han // This generic template body should never be hit. 36*5f39d1b3SJooyung Han static_assert( 37*5f39d1b3SJooyung Han std::is_same<InputBufferType, void>::value, 38*5f39d1b3SJooyung Han "Unimplemented: missing implementation of this output pipeline stage " 39*5f39d1b3SJooyung Han "for this data type. This would happen if some architecture-specific " 40*5f39d1b3SJooyung Han "SIMD back-end (output_$arch.h) were incomplete."); 41*5f39d1b3SJooyung Han }; 42*5f39d1b3SJooyung Han 43*5f39d1b3SJooyung Han template <typename OutputStage, typename InputType> 44*5f39d1b3SJooyung Han struct OutputStageEvalImpl { 45*5f39d1b3SJooyung Han static constexpr int kRows = InputType::kRows; 46*5f39d1b3SJooyung Han static constexpr int kCols = InputType::kCols; 47*5f39d1b3SJooyung Han using InputBufferType = typename InputType::BufferType; 48*5f39d1b3SJooyung Han using BufferEvalImplType = 49*5f39d1b3SJooyung Han OutputStageEvalBufferImpl<OutputStage, InputBufferType>; 50*5f39d1b3SJooyung Han using OutputBufferType = typename BufferEvalImplType::OutputType; 51*5f39d1b3SJooyung Han using OutputScalarType = typename OutputBufferType::ScalarType; 52*5f39d1b3SJooyung Han using OutputType = RegisterBlock<OutputScalarType, kRows, kCols>; 53*5f39d1b3SJooyung Han OutputStageEvalImplOutputStageEvalImpl54*5f39d1b3SJooyung Han OutputStageEvalImpl(const OutputStage& s) : buffer_eval_impl(s) {} 55*5f39d1b3SJooyung Han EvalOutputStageEvalImpl56*5f39d1b3SJooyung Han OutputType Eval(InputType input, int, int) const { 57*5f39d1b3SJooyung Han OutputType output; 58*5f39d1b3SJooyung Han output.buf = buffer_eval_impl.Eval(input.buf); 59*5f39d1b3SJooyung Han return output; 60*5f39d1b3SJooyung Han } 61*5f39d1b3SJooyung Han 62*5f39d1b3SJooyung Han const BufferEvalImplType buffer_eval_impl; 63*5f39d1b3SJooyung Han }; 64*5f39d1b3SJooyung Han 65*5f39d1b3SJooyung Han template <int Size> 66*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageQuantizeDownInt32ToUint8Scale, 67*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 68*5f39d1b3SJooyung Han using InputType = RegisterBuffer<std::int32_t, Size>; 69*5f39d1b3SJooyung Han using OutputType = RegisterBuffer<std::int32_t, Size>; 70*5f39d1b3SJooyung Han 71*5f39d1b3SJooyung Han typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage; 72*5f39d1b3SJooyung Han 73*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} 74*5f39d1b3SJooyung Han 75*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 76*5f39d1b3SJooyung Han const int result_shift = output_stage.result_shift; 77*5f39d1b3SJooyung Han const std::int32_t result_mult_int = output_stage.result_mult_int; 78*5f39d1b3SJooyung Han using RegisterType = typename InputType::RegisterType; 79*5f39d1b3SJooyung Han const RegisterType result_offset = 80*5f39d1b3SJooyung Han Dup<RegisterType>(output_stage.result_offset); 81*5f39d1b3SJooyung Han OutputType output; 82*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 83*5f39d1b3SJooyung Han output.reg[i] = RoundingDivideByPOT( 84*5f39d1b3SJooyung Han Mul(Add(input.reg[i], result_offset), result_mult_int), result_shift); 85*5f39d1b3SJooyung Han } 86*5f39d1b3SJooyung Han return output; 87*5f39d1b3SJooyung Han } 88*5f39d1b3SJooyung Han 89*5f39d1b3SJooyung Han const OutputStage& output_stage; 90*5f39d1b3SJooyung Han }; 91*5f39d1b3SJooyung Han 92*5f39d1b3SJooyung Han template <int Rows, int Cols, VectorShape Shape> 93*5f39d1b3SJooyung Han struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>, 94*5f39d1b3SJooyung Han RegisterBlock<std::int32_t, Rows, Cols>> { 95*5f39d1b3SJooyung Han typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; 96*5f39d1b3SJooyung Han typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; 97*5f39d1b3SJooyung Han typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> OutputStage; 98*5f39d1b3SJooyung Han 99*5f39d1b3SJooyung Han OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} 100*5f39d1b3SJooyung Han 101*5f39d1b3SJooyung Han OutputType Eval(InputType input, int row, int col) const { 102*5f39d1b3SJooyung Han OutputType output; 103*5f39d1b3SJooyung Han const int result_shift = output_stage.result_shift; 104*5f39d1b3SJooyung Han const int pos = Shape == VectorShape::Col ? row : col; 105*5f39d1b3SJooyung Han const auto result_mult_int = 106*5f39d1b3SJooyung Han LoadForBroadcasting<InputType>(output_stage.result_mult_int, pos); 107*5f39d1b3SJooyung Han const auto result_offset = 108*5f39d1b3SJooyung Han LoadForBroadcasting<InputType>(output_stage.result_offset, pos); 109*5f39d1b3SJooyung Han const auto dividend = BroadcastMul<InputType>( 110*5f39d1b3SJooyung Han BroadcastAdd<InputType>(input, result_offset), result_mult_int); 111*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 112*5f39d1b3SJooyung Han output.buf.reg[i] = 113*5f39d1b3SJooyung Han RoundingDivideByPOT(dividend.buf.reg[i], result_shift); 114*5f39d1b3SJooyung Han } 115*5f39d1b3SJooyung Han return output; 116*5f39d1b3SJooyung Han } 117*5f39d1b3SJooyung Han 118*5f39d1b3SJooyung Han const OutputStage& output_stage; 119*5f39d1b3SJooyung Han }; 120*5f39d1b3SJooyung Han 121*5f39d1b3SJooyung Han template <int Size> 122*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl< 123*5f39d1b3SJooyung Han OutputStageQuantizeDownInt32ByFixedPoint, 124*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 125*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 126*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> OutputType; 127*5f39d1b3SJooyung Han 128*5f39d1b3SJooyung Han typedef OutputStageQuantizeDownInt32ByFixedPoint OutputStage; 129*5f39d1b3SJooyung Han 130*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} 131*5f39d1b3SJooyung Han 132*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 133*5f39d1b3SJooyung Han OutputType output; 134*5f39d1b3SJooyung Han using RegisterType = typename InputType::RegisterType; 135*5f39d1b3SJooyung Han const RegisterType result_offset_after_shift = 136*5f39d1b3SJooyung Han Dup<RegisterType>(output_stage.result_offset_after_shift); 137*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 138*5f39d1b3SJooyung Han const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul( 139*5f39d1b3SJooyung Han input.reg[i], output_stage.result_fixedpoint_multiplier); 140*5f39d1b3SJooyung Han output.reg[i] = 141*5f39d1b3SJooyung Han Add(RoundingDivideByPOT(mulhigh_val, output_stage.result_shift), 142*5f39d1b3SJooyung Han result_offset_after_shift); 143*5f39d1b3SJooyung Han } 144*5f39d1b3SJooyung Han return output; 145*5f39d1b3SJooyung Han } 146*5f39d1b3SJooyung Han 147*5f39d1b3SJooyung Han const OutputStage& output_stage; 148*5f39d1b3SJooyung Han }; 149*5f39d1b3SJooyung Han 150*5f39d1b3SJooyung Han template <int Size> 151*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent, 152*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 153*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 154*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> OutputType; 155*5f39d1b3SJooyung Han 156*5f39d1b3SJooyung Han typedef OutputStageScaleInt32ByFixedPointAndExponent OutputStage; 157*5f39d1b3SJooyung Han 158*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) { 159*5f39d1b3SJooyung Han left_shift = std::max(0, output_stage.result_exponent); 160*5f39d1b3SJooyung Han right_shift = std::max(0, -output_stage.result_exponent); 161*5f39d1b3SJooyung Han } 162*5f39d1b3SJooyung Han 163*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 164*5f39d1b3SJooyung Han OutputType output; 165*5f39d1b3SJooyung Han using RegisterType = typename InputType::RegisterType; 166*5f39d1b3SJooyung Han const RegisterType result_offset_after_shift = 167*5f39d1b3SJooyung Han Dup<RegisterType>(output_stage.result_offset_after_shift); 168*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 169*5f39d1b3SJooyung Han const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul( 170*5f39d1b3SJooyung Han ShiftLeft(input.reg[i], left_shift), 171*5f39d1b3SJooyung Han output_stage.result_fixedpoint_multiplier); 172*5f39d1b3SJooyung Han output.reg[i] = Add(RoundingDivideByPOT(mulhigh_val, right_shift), 173*5f39d1b3SJooyung Han result_offset_after_shift); 174*5f39d1b3SJooyung Han } 175*5f39d1b3SJooyung Han return output; 176*5f39d1b3SJooyung Han } 177*5f39d1b3SJooyung Han 178*5f39d1b3SJooyung Han const OutputStage& output_stage; 179*5f39d1b3SJooyung Han int left_shift; 180*5f39d1b3SJooyung Han int right_shift; 181*5f39d1b3SJooyung Han }; 182*5f39d1b3SJooyung Han 183*5f39d1b3SJooyung Han template <int Rows, int Cols, VectorShape Shape> 184*5f39d1b3SJooyung Han struct OutputStageEvalImpl< 185*5f39d1b3SJooyung Han OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>, 186*5f39d1b3SJooyung Han RegisterBlock<std::int32_t, Rows, Cols>> { 187*5f39d1b3SJooyung Han typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; 188*5f39d1b3SJooyung Han typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; 189*5f39d1b3SJooyung Han 190*5f39d1b3SJooyung Han typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage; 191*5f39d1b3SJooyung Han 192*5f39d1b3SJooyung Han OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} 193*5f39d1b3SJooyung Han 194*5f39d1b3SJooyung Han OutputType Eval(InputType input, int row, int col) const { 195*5f39d1b3SJooyung Han OutputType output; 196*5f39d1b3SJooyung Han const int pos = Shape == VectorShape::Row ? col : row; 197*5f39d1b3SJooyung Han using RegisterType = typename InputType::RegisterType; 198*5f39d1b3SJooyung Han const RegisterType result_offset_after_shift = 199*5f39d1b3SJooyung Han Dup<RegisterType>(output_stage.result_offset_after_shift); 200*5f39d1b3SJooyung Han auto left_shift = 201*5f39d1b3SJooyung Han LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); 202*5f39d1b3SJooyung Han auto right_shift = 203*5f39d1b3SJooyung Han LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); 204*5f39d1b3SJooyung Han const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>( 205*5f39d1b3SJooyung Han output_stage.result_fixedpoint_multiplier, pos); 206*5f39d1b3SJooyung Han for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) { 207*5f39d1b3SJooyung Han left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0); 208*5f39d1b3SJooyung Han right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0); 209*5f39d1b3SJooyung Han } 210*5f39d1b3SJooyung Han const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul( 211*5f39d1b3SJooyung Han BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier); 212*5f39d1b3SJooyung Han const auto rdpot_val = 213*5f39d1b3SJooyung Han BroadcastRoundingDivideByPOT(mulhigh_val, right_shift); 214*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 215*5f39d1b3SJooyung Han output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift); 216*5f39d1b3SJooyung Han } 217*5f39d1b3SJooyung Han return output; 218*5f39d1b3SJooyung Han } 219*5f39d1b3SJooyung Han 220*5f39d1b3SJooyung Han const OutputStage& output_stage; 221*5f39d1b3SJooyung Han }; 222*5f39d1b3SJooyung Han 223*5f39d1b3SJooyung Han // Implementation of OutputStageSaturatingCastToUint8 for scalar data. 224*5f39d1b3SJooyung Han template <int Size> 225*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 226*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 227*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 228*5f39d1b3SJooyung Han typedef RegisterBuffer<std::uint8_t, Size> OutputType; 229*5f39d1b3SJooyung Han static_assert(InputType::kRegisterLanes == 1, 230*5f39d1b3SJooyung Han "This path is only for scalar values"); 231*5f39d1b3SJooyung Han 232*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToUint8 OutputStage; 233*5f39d1b3SJooyung Han 234*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 235*5f39d1b3SJooyung Han 236*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 237*5f39d1b3SJooyung Han OutputType output; 238*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 239*5f39d1b3SJooyung Han std::int32_t data = input.reg[i]; 240*5f39d1b3SJooyung Han output.reg[i] = data > 255 ? 255 : data < 0 ? 0 : data; 241*5f39d1b3SJooyung Han } 242*5f39d1b3SJooyung Han return output; 243*5f39d1b3SJooyung Han } 244*5f39d1b3SJooyung Han }; 245*5f39d1b3SJooyung Han 246*5f39d1b3SJooyung Han // Implementation of OutputStageSaturatingCastToInt8 for scalar data. 247*5f39d1b3SJooyung Han template <int Size> 248*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, 249*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 250*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 251*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int8_t, Size> OutputType; 252*5f39d1b3SJooyung Han static_assert(InputType::kRegisterLanes == 1, 253*5f39d1b3SJooyung Han "This path is only for scalar values"); 254*5f39d1b3SJooyung Han 255*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToInt8 OutputStage; 256*5f39d1b3SJooyung Han 257*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 258*5f39d1b3SJooyung Han 259*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 260*5f39d1b3SJooyung Han OutputType output; 261*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 262*5f39d1b3SJooyung Han std::int32_t data = input.reg[i]; 263*5f39d1b3SJooyung Han output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data; 264*5f39d1b3SJooyung Han } 265*5f39d1b3SJooyung Han return output; 266*5f39d1b3SJooyung Han } 267*5f39d1b3SJooyung Han }; 268*5f39d1b3SJooyung Han 269*5f39d1b3SJooyung Han // Implementation of OutputStageSaturatingCastToInt16 for scalar data. 270*5f39d1b3SJooyung Han template <int Size> 271*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, 272*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 273*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 274*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int16_t, Size> OutputType; 275*5f39d1b3SJooyung Han static_assert(InputType::kRegisterLanes == 1, 276*5f39d1b3SJooyung Han "This path is only for scalar values"); 277*5f39d1b3SJooyung Han 278*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToInt16 OutputStage; 279*5f39d1b3SJooyung Han 280*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 281*5f39d1b3SJooyung Han 282*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 283*5f39d1b3SJooyung Han OutputType output; 284*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 285*5f39d1b3SJooyung Han std::int32_t data = input.reg[i]; 286*5f39d1b3SJooyung Han output.reg[i] = data > 32767 ? 32767 : data < -32768 ? -32768 : data; 287*5f39d1b3SJooyung Han } 288*5f39d1b3SJooyung Han return output; 289*5f39d1b3SJooyung Han } 290*5f39d1b3SJooyung Han }; 291*5f39d1b3SJooyung Han 292*5f39d1b3SJooyung Han // Implementation of OutputStageTruncatingCastToUint8 for scalar data 293*5f39d1b3SJooyung Han template <int Size> 294*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, 295*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 296*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 297*5f39d1b3SJooyung Han typedef RegisterBuffer<std::uint8_t, Size> OutputType; 298*5f39d1b3SJooyung Han static_assert(InputType::kRegisterLanes == 1, 299*5f39d1b3SJooyung Han "This path is only for scalar values"); 300*5f39d1b3SJooyung Han 301*5f39d1b3SJooyung Han typedef OutputStageTruncatingCastToUint8 OutputStage; 302*5f39d1b3SJooyung Han 303*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 304*5f39d1b3SJooyung Han 305*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 306*5f39d1b3SJooyung Han OutputType output; 307*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 308*5f39d1b3SJooyung Han output.reg[i] = input.reg[i]; 309*5f39d1b3SJooyung Han } 310*5f39d1b3SJooyung Han return output; 311*5f39d1b3SJooyung Han } 312*5f39d1b3SJooyung Han }; 313*5f39d1b3SJooyung Han 314*5f39d1b3SJooyung Han template <int Rows, int Cols, typename VectorType> 315*5f39d1b3SJooyung Han struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>, 316*5f39d1b3SJooyung Han RegisterBlock<std::int32_t, Rows, Cols>> { 317*5f39d1b3SJooyung Han typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; 318*5f39d1b3SJooyung Han typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; 319*5f39d1b3SJooyung Han typedef OutputStageBiasAddition<VectorType> OutputStage; 320*5f39d1b3SJooyung Han 321*5f39d1b3SJooyung Han OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} 322*5f39d1b3SJooyung Han 323*5f39d1b3SJooyung Han OutputType Eval(InputType input, int row, int col) const { 324*5f39d1b3SJooyung Han const int pos = VectorType::kShape == VectorShape::Row ? col : row; 325*5f39d1b3SJooyung Han return BroadcastAdd<InputType>( 326*5f39d1b3SJooyung Han input, LoadForBroadcasting<InputType>(output_stage.bias_vector, pos)); 327*5f39d1b3SJooyung Han } 328*5f39d1b3SJooyung Han 329*5f39d1b3SJooyung Han const OutputStage& output_stage; 330*5f39d1b3SJooyung Han }; 331*5f39d1b3SJooyung Han 332*5f39d1b3SJooyung Han template <int Size> 333*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageClamp, 334*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 335*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 336*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> OutputType; 337*5f39d1b3SJooyung Han 338*5f39d1b3SJooyung Han typedef OutputStageClamp OutputStage; 339*5f39d1b3SJooyung Han 340*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} 341*5f39d1b3SJooyung Han 342*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 343*5f39d1b3SJooyung Han using RegisterType = typename InputType::RegisterType; 344*5f39d1b3SJooyung Han const RegisterType min = Dup<RegisterType>(output_stage.min); 345*5f39d1b3SJooyung Han const RegisterType max = Dup<RegisterType>(output_stage.max); 346*5f39d1b3SJooyung Han OutputType output; 347*5f39d1b3SJooyung Han for (int i = 0; i < InputType::kRegisterCount; i++) { 348*5f39d1b3SJooyung Han output.reg[i] = Min(Max(input.reg[i], min), max); 349*5f39d1b3SJooyung Han } 350*5f39d1b3SJooyung Han return output; 351*5f39d1b3SJooyung Han } 352*5f39d1b3SJooyung Han 353*5f39d1b3SJooyung Han const OutputStage& output_stage; 354*5f39d1b3SJooyung Han }; 355*5f39d1b3SJooyung Han 356*5f39d1b3SJooyung Han template <int Size> 357*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageTanh, 358*5f39d1b3SJooyung Han RegisterBuffer<std::int32_t, Size>> { 359*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> InputType; 360*5f39d1b3SJooyung Han typedef RegisterBuffer<std::int32_t, Size> OutputType; 361*5f39d1b3SJooyung Han using RegisterType = typename InputType::RegisterType; 362*5f39d1b3SJooyung Han typedef RegisterType DataType; 363*5f39d1b3SJooyung Han typedef OutputStageTanh OutputStage; 364*5f39d1b3SJooyung Han 365*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) { 366*5f39d1b3SJooyung Han const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; 367*5f39d1b3SJooyung Han const std::int32_t real_amplitude_as_int32 = 368*5f39d1b3SJooyung Han output_stage.real_amplitude_as_int32; 369*5f39d1b3SJooyung Han 370*5f39d1b3SJooyung Han input_cutoff_min = real_zero_as_int32 - 8 * real_amplitude_as_int32; 371*5f39d1b3SJooyung Han input_cutoff_max = real_zero_as_int32 + 8 * real_amplitude_as_int32; 372*5f39d1b3SJooyung Han output_min = real_zero_as_int32 - real_amplitude_as_int32; 373*5f39d1b3SJooyung Han output_max = real_zero_as_int32 + real_amplitude_as_int32; 374*5f39d1b3SJooyung Han 375*5f39d1b3SJooyung Han double inverse_amplitude_normalized_double = 1.0 / real_amplitude_as_int32; 376*5f39d1b3SJooyung Han inverse_amplitude_neg_exponent = 0; 377*5f39d1b3SJooyung Han while (inverse_amplitude_normalized_double < 0.5) { 378*5f39d1b3SJooyung Han inverse_amplitude_normalized_double *= 2; 379*5f39d1b3SJooyung Han inverse_amplitude_neg_exponent++; 380*5f39d1b3SJooyung Han } 381*5f39d1b3SJooyung Han inverse_amplitude_normalized = FixedPoint<DataType, 0>::FromDouble( 382*5f39d1b3SJooyung Han inverse_amplitude_normalized_double); 383*5f39d1b3SJooyung Han 384*5f39d1b3SJooyung Han double amplitude_normalized_double = real_amplitude_as_int32; 385*5f39d1b3SJooyung Han amplitude_exponent = 0; 386*5f39d1b3SJooyung Han while (amplitude_normalized_double >= 1.0) { 387*5f39d1b3SJooyung Han amplitude_normalized_double *= 0.5; 388*5f39d1b3SJooyung Han amplitude_exponent++; 389*5f39d1b3SJooyung Han } 390*5f39d1b3SJooyung Han amplitude_normalized = 391*5f39d1b3SJooyung Han FixedPoint<DataType, 0>::FromDouble(amplitude_normalized_double); 392*5f39d1b3SJooyung Han } 393*5f39d1b3SJooyung Han 394*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 395*5f39d1b3SJooyung Han const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; 396*5f39d1b3SJooyung Han 397*5f39d1b3SJooyung Han typedef FixedPoint<DataType, 3> F3; 398*5f39d1b3SJooyung Han typedef FixedPoint<DataType, 0> F0; 399*5f39d1b3SJooyung Han 400*5f39d1b3SJooyung Han OutputType output; 401*5f39d1b3SJooyung Han 402*5f39d1b3SJooyung Han for (int i = 0; i < OutputType::kRegisterCount; i++) { 403*5f39d1b3SJooyung Han // fixed-point affine transformation 404*5f39d1b3SJooyung Han DataType input_centered = 405*5f39d1b3SJooyung Han Sub(input.reg[i], Dup<DataType>(real_zero_as_int32)); 406*5f39d1b3SJooyung Han F3 fixedpoint_input = 407*5f39d1b3SJooyung Han F3::FromRaw(input_centered) * inverse_amplitude_normalized; 408*5f39d1b3SJooyung Han // left shift 409*5f39d1b3SJooyung Han fixedpoint_input.raw() = ShiftLeft(fixedpoint_input.raw(), 410*5f39d1b3SJooyung Han 28 - inverse_amplitude_neg_exponent); 411*5f39d1b3SJooyung Han // fixed-point tanh and multiplication 412*5f39d1b3SJooyung Han F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized; 413*5f39d1b3SJooyung Han // right shift 414*5f39d1b3SJooyung Han DataType int32_output = 415*5f39d1b3SJooyung Han Add(Dup<DataType>(real_zero_as_int32), 416*5f39d1b3SJooyung Han ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent)); 417*5f39d1b3SJooyung Han 418*5f39d1b3SJooyung Han DataType mask_if_below_cutoff_min = 419*5f39d1b3SJooyung Han MaskIfLessThanOrEqual(input.reg[i], Dup<DataType>(input_cutoff_min)); 420*5f39d1b3SJooyung Han DataType mask_if_above_cutoff_max = MaskIfGreaterThanOrEqual( 421*5f39d1b3SJooyung Han input.reg[i], Dup<DataType>(input_cutoff_max)); 422*5f39d1b3SJooyung Han 423*5f39d1b3SJooyung Han output.reg[i] = SelectUsingMask( 424*5f39d1b3SJooyung Han mask_if_below_cutoff_min, Dup<DataType>(output_min), 425*5f39d1b3SJooyung Han SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max), 426*5f39d1b3SJooyung Han int32_output)); 427*5f39d1b3SJooyung Han } 428*5f39d1b3SJooyung Han return output; 429*5f39d1b3SJooyung Han } 430*5f39d1b3SJooyung Han 431*5f39d1b3SJooyung Han const OutputStage& output_stage; 432*5f39d1b3SJooyung Han std::int32_t input_cutoff_min, input_cutoff_max; 433*5f39d1b3SJooyung Han std::int32_t output_min, output_max; 434*5f39d1b3SJooyung Han FixedPoint<DataType, 0> inverse_amplitude_normalized; 435*5f39d1b3SJooyung Han int inverse_amplitude_neg_exponent; 436*5f39d1b3SJooyung Han FixedPoint<DataType, 0> amplitude_normalized; 437*5f39d1b3SJooyung Han int amplitude_exponent; 438*5f39d1b3SJooyung Han }; 439*5f39d1b3SJooyung Han 440*5f39d1b3SJooyung Han // OutputPipelineOutputType is a helper to determine the output data type of a 441*5f39d1b3SJooyung Han // pipeline, for a 442*5f39d1b3SJooyung Han // given input data type. It is a recursive template; see the explanation on 443*5f39d1b3SJooyung Han // OutputPipelineEvalImpl below. 444*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType, 445*5f39d1b3SJooyung Han bool StopRecursion = 446*5f39d1b3SJooyung Han FirstStage == std::tuple_size<OutputPipelineType>::value> 447*5f39d1b3SJooyung Han struct OutputPipelineOutputType { 448*5f39d1b3SJooyung Han typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type 449*5f39d1b3SJooyung Han FirstStageType; 450*5f39d1b3SJooyung Han typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType 451*5f39d1b3SJooyung Han FirstStageOutputType; 452*5f39d1b3SJooyung Han typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage + 1, 453*5f39d1b3SJooyung Han FirstStageOutputType>::Type Type; 454*5f39d1b3SJooyung Han }; 455*5f39d1b3SJooyung Han 456*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType> 457*5f39d1b3SJooyung Han struct OutputPipelineOutputType<OutputPipelineType, FirstStage, InputType, 458*5f39d1b3SJooyung Han true> { 459*5f39d1b3SJooyung Han typedef InputType Type; 460*5f39d1b3SJooyung Han }; 461*5f39d1b3SJooyung Han 462*5f39d1b3SJooyung Han // OutputPipelineEvalImpl is a helper to implement the evaluation of 463*5f39d1b3SJooyung Han // the whole pipeline. It is a recursive template to implement compile-time 464*5f39d1b3SJooyung Han // unrolling of the loop over all pipeline stages. The 'FirstStage' parameter 465*5f39d1b3SJooyung Han // is how we implement recursion: each specialization implements only 466*5f39d1b3SJooyung Han // evaluation starting at 'FirstStage'. The StopRecursion parameter is just a 467*5f39d1b3SJooyung Han // helper to implement the termination of the recursion as a partial 468*5f39d1b3SJooyung Han // specialization below. 469*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType, 470*5f39d1b3SJooyung Han bool StopRecursion = 471*5f39d1b3SJooyung Han FirstStage == std::tuple_size<OutputPipelineType>::value> 472*5f39d1b3SJooyung Han struct OutputPipelineEvalImpl { 473*5f39d1b3SJooyung Han typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type 474*5f39d1b3SJooyung Han FirstStageType; 475*5f39d1b3SJooyung Han typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType 476*5f39d1b3SJooyung Han FirstStageOutputType; 477*5f39d1b3SJooyung Han typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage, 478*5f39d1b3SJooyung Han InputType>::Type OutputType; 479*5f39d1b3SJooyung Han 480*5f39d1b3SJooyung Han OutputPipelineEvalImpl(const OutputPipelineType& output_pipeline) 481*5f39d1b3SJooyung Han : head_impl(std::get<FirstStage>(output_pipeline)), 482*5f39d1b3SJooyung Han tail_impl(output_pipeline) {} 483*5f39d1b3SJooyung Han 484*5f39d1b3SJooyung Han OutputType Eval(InputType input, int row, int col) const { 485*5f39d1b3SJooyung Han // Evaluate the first stage. 486*5f39d1b3SJooyung Han FirstStageOutputType first_stage_output = head_impl.Eval(input, row, col); 487*5f39d1b3SJooyung Han // Recurse into the remaining stages. 488*5f39d1b3SJooyung Han return tail_impl.Eval(first_stage_output, row, col); 489*5f39d1b3SJooyung Han } 490*5f39d1b3SJooyung Han 491*5f39d1b3SJooyung Han const OutputStageEvalImpl<FirstStageType, InputType> head_impl; 492*5f39d1b3SJooyung Han const OutputPipelineEvalImpl<OutputPipelineType, FirstStage + 1, 493*5f39d1b3SJooyung Han FirstStageOutputType> 494*5f39d1b3SJooyung Han tail_impl; 495*5f39d1b3SJooyung Han }; 496*5f39d1b3SJooyung Han 497*5f39d1b3SJooyung Han // Specialization on 'StopRecursion' for terminating the recursion. 498*5f39d1b3SJooyung Han template <typename OutputPipelineType, int FirstStage, typename InputType> 499*5f39d1b3SJooyung Han struct OutputPipelineEvalImpl<OutputPipelineType, FirstStage, InputType, true> { 500*5f39d1b3SJooyung Han OutputPipelineEvalImpl(const OutputPipelineType&) {} 501*5f39d1b3SJooyung Han 502*5f39d1b3SJooyung Han InputType Eval(InputType input, int, int) const { 503*5f39d1b3SJooyung Han // Terminating the recursion. 504*5f39d1b3SJooyung Han return input; 505*5f39d1b3SJooyung Han } 506*5f39d1b3SJooyung Han }; 507*5f39d1b3SJooyung Han 508*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename DstType> 509*5f39d1b3SJooyung Han struct StoreFinalOutputImpl { 510*5f39d1b3SJooyung Han static_assert(std::is_same<RegisterBlockType, void>::value, 511*5f39d1b3SJooyung Han "This generic impl should never be hit"); 512*5f39d1b3SJooyung Han }; 513*5f39d1b3SJooyung Han 514*5f39d1b3SJooyung Han template <typename ScalarType, int Rows, int Cols, typename DstType> 515*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegisterBlock<ScalarType, Rows, Cols>, DstType> { 516*5f39d1b3SJooyung Han using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; 517*5f39d1b3SJooyung Han static void Run(const RegisterBlockType& src, DstType* dst, int row, 518*5f39d1b3SJooyung Han int col) { 519*5f39d1b3SJooyung Han for (int r = 0; r < Rows; r++) { 520*5f39d1b3SJooyung Han for (int c = 0; c < Cols; c++) { 521*5f39d1b3SJooyung Han *dst->data(row + r, col + c) = src.buf.reg[r + c * Rows]; 522*5f39d1b3SJooyung Han } 523*5f39d1b3SJooyung Han } 524*5f39d1b3SJooyung Han } 525*5f39d1b3SJooyung Han }; 526*5f39d1b3SJooyung Han 527*5f39d1b3SJooyung Han // StoreFinalOutput takes the final value at the end of the output pipeline and 528*5f39d1b3SJooyung Han // stores it into the destination matrix. It can be specialized for different 529*5f39d1b3SJooyung Han // data types; the generic implementation here is typically used only for plain 530*5f39d1b3SJooyung Han // old scalar (not SIMD) types. 531*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename DstType> 532*5f39d1b3SJooyung Han void StoreFinalOutput(RegisterBlockType src, DstType* dst, int row, int col) { 533*5f39d1b3SJooyung Han StoreFinalOutputImpl<RegisterBlockType, DstType>::Run(src, dst, row, col); 534*5f39d1b3SJooyung Han } 535*5f39d1b3SJooyung Han 536*5f39d1b3SJooyung Han template <typename OutputPipelineType, typename InputType> 537*5f39d1b3SJooyung Han struct OutputPipelineExecutor { 538*5f39d1b3SJooyung Han OutputPipelineExecutor(const OutputPipelineType& output_pipeline) 539*5f39d1b3SJooyung Han : output_pipeline_eval_impl_(output_pipeline) {} 540*5f39d1b3SJooyung Han 541*5f39d1b3SJooyung Han // Execute is the entry point into the output pipeline evaluation 542*5f39d1b3SJooyung Han // code. It should be the only thing that unpack code calls. It takes the 543*5f39d1b3SJooyung Han // result 544*5f39d1b3SJooyung Han // of the unpack stage and stores it into the destination matrix. 545*5f39d1b3SJooyung Han template <typename DstType> 546*5f39d1b3SJooyung Han void Execute(InputType input, DstType* dst, int src_global_row, 547*5f39d1b3SJooyung Han int src_global_col, int dst_row, int dst_col) const { 548*5f39d1b3SJooyung Han // Statically assert that the output pipeline matches the given destination 549*5f39d1b3SJooyung Han // matrix's scalar type. 550*5f39d1b3SJooyung Han typedef typename OutputPipelineOutputType< 551*5f39d1b3SJooyung Han OutputPipelineType, 0, InputType>::Type::BufferType::ScalarType 552*5f39d1b3SJooyung Han 553*5f39d1b3SJooyung Han ScalarOutputType; 554*5f39d1b3SJooyung Han typedef typename DstType::Scalar ScalarDstType; 555*5f39d1b3SJooyung Han static_assert(std::is_same<ScalarOutputType, ScalarDstType>::value, 556*5f39d1b3SJooyung Han "mismatched destination scalar type and output pipeline"); 557*5f39d1b3SJooyung Han 558*5f39d1b3SJooyung Han // Evaluate the output pipeline. 559*5f39d1b3SJooyung Han auto output = 560*5f39d1b3SJooyung Han output_pipeline_eval_impl_.Eval(input, src_global_row, src_global_col); 561*5f39d1b3SJooyung Han // Store the result into the destination matrix. 562*5f39d1b3SJooyung Han StoreFinalOutput(output, dst, dst_row, dst_col); 563*5f39d1b3SJooyung Han } 564*5f39d1b3SJooyung Han 565*5f39d1b3SJooyung Han const OutputPipelineEvalImpl<OutputPipelineType, 0, InputType> 566*5f39d1b3SJooyung Han output_pipeline_eval_impl_; 567*5f39d1b3SJooyung Han }; 568*5f39d1b3SJooyung Han 569*5f39d1b3SJooyung Han } // namespace gemmlowp 570*5f39d1b3SJooyung Han 571*5f39d1b3SJooyung Han #ifdef GEMMLOWP_NEON 572*5f39d1b3SJooyung Han #include "output_neon.h" 573*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_SSE4) 574*5f39d1b3SJooyung Han #include "output_sse.h" 575*5f39d1b3SJooyung Han #elif defined(GEMMLOWP_MSA) 576*5f39d1b3SJooyung Han #include "output_msa.h" 577*5f39d1b3SJooyung Han #endif 578*5f39d1b3SJooyung Han 579*5f39d1b3SJooyung Han #endif // GEMMLOWP_INTERNAL_OUTPUT_H_ 580