1*5f39d1b3SJooyung Han // Copyright 2015 Google Inc. All Rights Reserved. 2*5f39d1b3SJooyung Han // 3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License"); 4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License. 5*5f39d1b3SJooyung Han // You may obtain a copy of the License at 6*5f39d1b3SJooyung Han // 7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0 8*5f39d1b3SJooyung Han // 9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software 10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS, 11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and 13*5f39d1b3SJooyung Han // limitations under the License. 14*5f39d1b3SJooyung Han 15*5f39d1b3SJooyung Han // output_sse.h: optimized SSE4.2 specializations of the templates in output.h. 16*5f39d1b3SJooyung Han 17*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ 18*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ 19*5f39d1b3SJooyung Han 20*5f39d1b3SJooyung Han #include "output.h" 21*5f39d1b3SJooyung Han 22*5f39d1b3SJooyung Han #include <smmintrin.h> 23*5f39d1b3SJooyung Han 24*5f39d1b3SJooyung Han namespace gemmlowp { 25*5f39d1b3SJooyung Han 26*5f39d1b3SJooyung Han template <> 27*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 28*5f39d1b3SJooyung Han RegBufferInt32<4>> { 29*5f39d1b3SJooyung Han typedef RegBufferInt32<4> InputType; 30*5f39d1b3SJooyung Han typedef RegBufferUint8<4> OutputType; 31*5f39d1b3SJooyung Han 32*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToUint8 OutputStage; 33*5f39d1b3SJooyung Han 34*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 35*5f39d1b3SJooyung Han 36*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 37*5f39d1b3SJooyung Han OutputType output; 38*5f39d1b3SJooyung Han __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]); 39*5f39d1b3SJooyung Han __m128i res_8 = _mm_packus_epi16(res_16, res_16); 40*5f39d1b3SJooyung Han output.reg[0] = _mm_cvtsi128_si32(res_8); 41*5f39d1b3SJooyung Han return output; 42*5f39d1b3SJooyung Han } 43*5f39d1b3SJooyung Han }; 44*5f39d1b3SJooyung Han 45*5f39d1b3SJooyung Han template <> 46*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 47*5f39d1b3SJooyung Han RegBufferInt32<8>> { 48*5f39d1b3SJooyung Han typedef RegBufferInt32<8> InputType; 49*5f39d1b3SJooyung Han typedef RegBufferUint8<8> OutputType; 50*5f39d1b3SJooyung Han 51*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToUint8 OutputStage; 52*5f39d1b3SJooyung Han 53*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 54*5f39d1b3SJooyung Han 55*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 56*5f39d1b3SJooyung Han OutputType output; 57*5f39d1b3SJooyung Han __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]); 58*5f39d1b3SJooyung Han __m128i res_8 = _mm_packus_epi16(res_16, res_16); 59*5f39d1b3SJooyung Han output.reg[0] = _mm_extract_epi32(res_8, 0); 60*5f39d1b3SJooyung Han output.reg[1] = _mm_extract_epi32(res_8, 1); 61*5f39d1b3SJooyung Han return output; 62*5f39d1b3SJooyung Han } 63*5f39d1b3SJooyung Han }; 64*5f39d1b3SJooyung Han 65*5f39d1b3SJooyung Han template <> 66*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 67*5f39d1b3SJooyung Han RegBufferInt32<16>> { 68*5f39d1b3SJooyung Han typedef RegBufferInt32<16> InputType; 69*5f39d1b3SJooyung Han typedef RegBufferUint8<16> OutputType; 70*5f39d1b3SJooyung Han 71*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToUint8 OutputStage; 72*5f39d1b3SJooyung Han 73*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 74*5f39d1b3SJooyung Han 75*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 76*5f39d1b3SJooyung Han OutputType output; 77*5f39d1b3SJooyung Han __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); 78*5f39d1b3SJooyung Han __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); 79*5f39d1b3SJooyung Han output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); 80*5f39d1b3SJooyung Han return output; 81*5f39d1b3SJooyung Han } 82*5f39d1b3SJooyung Han }; 83*5f39d1b3SJooyung Han 84*5f39d1b3SJooyung Han template <> 85*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 86*5f39d1b3SJooyung Han RegBufferInt32<32>> { 87*5f39d1b3SJooyung Han typedef RegBufferInt32<32> InputType; 88*5f39d1b3SJooyung Han typedef RegBufferUint8<32> OutputType; 89*5f39d1b3SJooyung Han 90*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToUint8 OutputStage; 91*5f39d1b3SJooyung Han 92*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 93*5f39d1b3SJooyung Han 94*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 95*5f39d1b3SJooyung Han OutputType output; 96*5f39d1b3SJooyung Han __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); 97*5f39d1b3SJooyung Han __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); 98*5f39d1b3SJooyung Han output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); 99*5f39d1b3SJooyung Han __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]); 100*5f39d1b3SJooyung Han __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]); 101*5f39d1b3SJooyung Han output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3); 102*5f39d1b3SJooyung Han return output; 103*5f39d1b3SJooyung Han } 104*5f39d1b3SJooyung Han }; 105*5f39d1b3SJooyung Han 106*5f39d1b3SJooyung Han template <> 107*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, 108*5f39d1b3SJooyung Han RegBufferInt32<4>> { 109*5f39d1b3SJooyung Han typedef RegBufferInt32<4> InputType; 110*5f39d1b3SJooyung Han typedef RegBufferInt16<4> OutputType; 111*5f39d1b3SJooyung Han 112*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToInt16 OutputStage; 113*5f39d1b3SJooyung Han 114*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 115*5f39d1b3SJooyung Han 116*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 117*5f39d1b3SJooyung Han OutputType output; 118*5f39d1b3SJooyung Han __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]); 119*5f39d1b3SJooyung Han output.reg[0] = _mm_extract_epi16(res_16, 0); 120*5f39d1b3SJooyung Han output.reg[1] = _mm_extract_epi16(res_16, 1); 121*5f39d1b3SJooyung Han output.reg[2] = _mm_extract_epi16(res_16, 2); 122*5f39d1b3SJooyung Han output.reg[3] = _mm_extract_epi16(res_16, 3); 123*5f39d1b3SJooyung Han return output; 124*5f39d1b3SJooyung Han } 125*5f39d1b3SJooyung Han }; 126*5f39d1b3SJooyung Han 127*5f39d1b3SJooyung Han template <> 128*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, 129*5f39d1b3SJooyung Han RegBufferInt32<8>> { 130*5f39d1b3SJooyung Han typedef RegBufferInt32<8> InputType; 131*5f39d1b3SJooyung Han typedef RegBufferInt16<8> OutputType; 132*5f39d1b3SJooyung Han 133*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToInt16 OutputStage; 134*5f39d1b3SJooyung Han 135*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 136*5f39d1b3SJooyung Han 137*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 138*5f39d1b3SJooyung Han OutputType output; 139*5f39d1b3SJooyung Han output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]); 140*5f39d1b3SJooyung Han return output; 141*5f39d1b3SJooyung Han } 142*5f39d1b3SJooyung Han }; 143*5f39d1b3SJooyung Han 144*5f39d1b3SJooyung Han template <> 145*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, 146*5f39d1b3SJooyung Han RegBufferInt32<16>> { 147*5f39d1b3SJooyung Han typedef RegBufferInt32<16> InputType; 148*5f39d1b3SJooyung Han typedef RegBufferInt16<16> OutputType; 149*5f39d1b3SJooyung Han 150*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToInt16 OutputStage; 151*5f39d1b3SJooyung Han 152*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 153*5f39d1b3SJooyung Han 154*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 155*5f39d1b3SJooyung Han OutputType output; 156*5f39d1b3SJooyung Han output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]); 157*5f39d1b3SJooyung Han output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]); 158*5f39d1b3SJooyung Han return output; 159*5f39d1b3SJooyung Han } 160*5f39d1b3SJooyung Han }; 161*5f39d1b3SJooyung Han 162*5f39d1b3SJooyung Han template <> 163*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, 164*5f39d1b3SJooyung Han RegBufferInt32<32>> { 165*5f39d1b3SJooyung Han typedef RegBufferInt32<32> InputType; 166*5f39d1b3SJooyung Han typedef RegBufferInt16<32> OutputType; 167*5f39d1b3SJooyung Han 168*5f39d1b3SJooyung Han typedef OutputStageSaturatingCastToInt16 OutputStage; 169*5f39d1b3SJooyung Han 170*5f39d1b3SJooyung Han OutputStageEvalBufferImpl(const OutputStage&) {} 171*5f39d1b3SJooyung Han 172*5f39d1b3SJooyung Han OutputType Eval(InputType input) const { 173*5f39d1b3SJooyung Han OutputType output; 174*5f39d1b3SJooyung Han output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]); 175*5f39d1b3SJooyung Han output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]); 176*5f39d1b3SJooyung Han output.reg[2] = _mm_packs_epi32(input.reg[4], input.reg[5]); 177*5f39d1b3SJooyung Han output.reg[3] = _mm_packs_epi32(input.reg[6], input.reg[7]); 178*5f39d1b3SJooyung Han return output; 179*5f39d1b3SJooyung Han } 180*5f39d1b3SJooyung Han }; 181*5f39d1b3SJooyung Han 182*5f39d1b3SJooyung Han template <typename DstType> 183*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> { 184*5f39d1b3SJooyung Han static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row, 185*5f39d1b3SJooyung Han int col) { 186*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 187*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row, col), src.buf.reg[0]); 188*5f39d1b3SJooyung Han } else { 189*5f39d1b3SJooyung Han *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); 190*5f39d1b3SJooyung Han *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); 191*5f39d1b3SJooyung Han *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); 192*5f39d1b3SJooyung Han *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); 193*5f39d1b3SJooyung Han } 194*5f39d1b3SJooyung Han } 195*5f39d1b3SJooyung Han }; 196*5f39d1b3SJooyung Han 197*5f39d1b3SJooyung Han template <typename DstType> 198*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> { 199*5f39d1b3SJooyung Han static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row, 200*5f39d1b3SJooyung Han int col) { 201*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 202*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row, col), src.buf.reg[0]); 203*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]); 204*5f39d1b3SJooyung Han } else { 205*5f39d1b3SJooyung Han *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); 206*5f39d1b3SJooyung Han *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); 207*5f39d1b3SJooyung Han *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); 208*5f39d1b3SJooyung Han *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); 209*5f39d1b3SJooyung Han *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]); 210*5f39d1b3SJooyung Han *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]); 211*5f39d1b3SJooyung Han *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]); 212*5f39d1b3SJooyung Han *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]); 213*5f39d1b3SJooyung Han } 214*5f39d1b3SJooyung Han } 215*5f39d1b3SJooyung Han }; 216*5f39d1b3SJooyung Han 217*5f39d1b3SJooyung Han template <typename DstType> 218*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> { 219*5f39d1b3SJooyung Han static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row, 220*5f39d1b3SJooyung Han int col) { 221*5f39d1b3SJooyung Han *dst->data(row + 0, col) = src.buf.reg[0]; 222*5f39d1b3SJooyung Han *dst->data(row + 1, col) = src.buf.reg[1]; 223*5f39d1b3SJooyung Han *dst->data(row + 2, col) = src.buf.reg[2]; 224*5f39d1b3SJooyung Han *dst->data(row + 3, col) = src.buf.reg[3]; 225*5f39d1b3SJooyung Han } 226*5f39d1b3SJooyung Han }; 227*5f39d1b3SJooyung Han 228*5f39d1b3SJooyung Han template <typename DstType> 229*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> { 230*5f39d1b3SJooyung Han static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row, 231*5f39d1b3SJooyung Han int col) { 232*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 233*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row, col), src.buf.reg[0]); 234*5f39d1b3SJooyung Han } else { 235*5f39d1b3SJooyung Han *dst->data(row + 0, col) = _mm_extract_epi16(src.buf.reg[0], 0); 236*5f39d1b3SJooyung Han *dst->data(row + 1, col) = _mm_extract_epi16(src.buf.reg[0], 1); 237*5f39d1b3SJooyung Han *dst->data(row + 2, col) = _mm_extract_epi16(src.buf.reg[0], 2); 238*5f39d1b3SJooyung Han *dst->data(row + 3, col) = _mm_extract_epi16(src.buf.reg[0], 3); 239*5f39d1b3SJooyung Han *dst->data(row + 4, col) = _mm_extract_epi16(src.buf.reg[0], 4); 240*5f39d1b3SJooyung Han *dst->data(row + 5, col) = _mm_extract_epi16(src.buf.reg[0], 5); 241*5f39d1b3SJooyung Han *dst->data(row + 6, col) = _mm_extract_epi16(src.buf.reg[0], 6); 242*5f39d1b3SJooyung Han *dst->data(row + 7, col) = _mm_extract_epi16(src.buf.reg[0], 7); 243*5f39d1b3SJooyung Han } 244*5f39d1b3SJooyung Han } 245*5f39d1b3SJooyung Han }; 246*5f39d1b3SJooyung Han 247*5f39d1b3SJooyung Han inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) { 248*5f39d1b3SJooyung Han __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]); 249*5f39d1b3SJooyung Han __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]); 250*5f39d1b3SJooyung Han __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]); 251*5f39d1b3SJooyung Han __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]); 252*5f39d1b3SJooyung Han 253*5f39d1b3SJooyung Han RegBlockInt32<4, 4> result; 254*5f39d1b3SJooyung Han result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1); 255*5f39d1b3SJooyung Han result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1); 256*5f39d1b3SJooyung Han result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3); 257*5f39d1b3SJooyung Han result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3); 258*5f39d1b3SJooyung Han return result; 259*5f39d1b3SJooyung Han } 260*5f39d1b3SJooyung Han 261*5f39d1b3SJooyung Han template <typename DstType> 262*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> { 263*5f39d1b3SJooyung Han static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row, 264*5f39d1b3SJooyung Han int col) { 265*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 266*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 267*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]); 268*5f39d1b3SJooyung Han } 269*5f39d1b3SJooyung Han } else { 270*5f39d1b3SJooyung Han const auto transpose = Transpose(src); 271*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 272*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]); 273*5f39d1b3SJooyung Han } 274*5f39d1b3SJooyung Han } 275*5f39d1b3SJooyung Han } 276*5f39d1b3SJooyung Han }; 277*5f39d1b3SJooyung Han 278*5f39d1b3SJooyung Han template <typename DstType> 279*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> { 280*5f39d1b3SJooyung Han static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row, 281*5f39d1b3SJooyung Han int col) { 282*5f39d1b3SJooyung Han std::int16_t buf[16]; 283*5f39d1b3SJooyung Han StoreInt16x8(buf + 0, src.buf.reg[0]); 284*5f39d1b3SJooyung Han StoreInt16x8(buf + 8, src.buf.reg[1]); 285*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 286*5f39d1b3SJooyung Han for (int j = 0; j < 4; j++) { 287*5f39d1b3SJooyung Han *dst->data(row + i, col + j) = buf[i + 4 * j]; 288*5f39d1b3SJooyung Han } 289*5f39d1b3SJooyung Han } 290*5f39d1b3SJooyung Han } 291*5f39d1b3SJooyung Han }; 292*5f39d1b3SJooyung Han 293*5f39d1b3SJooyung Han template <typename DstType> 294*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> { 295*5f39d1b3SJooyung Han static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row, 296*5f39d1b3SJooyung Han int col) { 297*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 298*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 299*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); 300*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); 301*5f39d1b3SJooyung Han } 302*5f39d1b3SJooyung Han } else { 303*5f39d1b3SJooyung Han RegBlockInt32<4, 4> top; 304*5f39d1b3SJooyung Han top.buf.reg[0] = src.buf.reg[0]; 305*5f39d1b3SJooyung Han top.buf.reg[1] = src.buf.reg[2]; 306*5f39d1b3SJooyung Han top.buf.reg[2] = src.buf.reg[4]; 307*5f39d1b3SJooyung Han top.buf.reg[3] = src.buf.reg[6]; 308*5f39d1b3SJooyung Han const auto transpose_top = Transpose(top); 309*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 310*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]); 311*5f39d1b3SJooyung Han } 312*5f39d1b3SJooyung Han RegBlockInt32<4, 4> bottom; 313*5f39d1b3SJooyung Han bottom.buf.reg[0] = src.buf.reg[1]; 314*5f39d1b3SJooyung Han bottom.buf.reg[1] = src.buf.reg[3]; 315*5f39d1b3SJooyung Han bottom.buf.reg[2] = src.buf.reg[5]; 316*5f39d1b3SJooyung Han bottom.buf.reg[3] = src.buf.reg[7]; 317*5f39d1b3SJooyung Han const auto transpose_bottom = Transpose(bottom); 318*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 319*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]); 320*5f39d1b3SJooyung Han } 321*5f39d1b3SJooyung Han } 322*5f39d1b3SJooyung Han } 323*5f39d1b3SJooyung Han }; 324*5f39d1b3SJooyung Han 325*5f39d1b3SJooyung Han template <typename DstType> 326*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> { 327*5f39d1b3SJooyung Han static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row, 328*5f39d1b3SJooyung Han int col) { 329*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 330*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 331*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]); 332*5f39d1b3SJooyung Han } 333*5f39d1b3SJooyung Han } else { 334*5f39d1b3SJooyung Han std::int16_t buf[32]; 335*5f39d1b3SJooyung Han StoreInt16x8(buf + 0, src.buf.reg[0]); 336*5f39d1b3SJooyung Han StoreInt16x8(buf + 8, src.buf.reg[1]); 337*5f39d1b3SJooyung Han StoreInt16x8(buf + 16, src.buf.reg[2]); 338*5f39d1b3SJooyung Han StoreInt16x8(buf + 24, src.buf.reg[3]); 339*5f39d1b3SJooyung Han for (int i = 0; i < 8; i++) { 340*5f39d1b3SJooyung Han for (int j = 0; j < 4; j++) { 341*5f39d1b3SJooyung Han *dst->data(row + i, col + j) = buf[i + 8 * j]; 342*5f39d1b3SJooyung Han } 343*5f39d1b3SJooyung Han } 344*5f39d1b3SJooyung Han } 345*5f39d1b3SJooyung Han } 346*5f39d1b3SJooyung Han }; 347*5f39d1b3SJooyung Han 348*5f39d1b3SJooyung Han template <typename DstType> 349*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> { 350*5f39d1b3SJooyung Han static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row, 351*5f39d1b3SJooyung Han int col) { 352*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 353*5f39d1b3SJooyung Han for (int i = 0; i < 8; i++) { 354*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); 355*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); 356*5f39d1b3SJooyung Han } 357*5f39d1b3SJooyung Han } else { 358*5f39d1b3SJooyung Han RegBlockInt32<4, 4> top_left; 359*5f39d1b3SJooyung Han top_left.buf.reg[0] = src.buf.reg[0]; 360*5f39d1b3SJooyung Han top_left.buf.reg[1] = src.buf.reg[2]; 361*5f39d1b3SJooyung Han top_left.buf.reg[2] = src.buf.reg[4]; 362*5f39d1b3SJooyung Han top_left.buf.reg[3] = src.buf.reg[6]; 363*5f39d1b3SJooyung Han const auto transpose_top_left = Transpose(top_left); 364*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 365*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]); 366*5f39d1b3SJooyung Han } 367*5f39d1b3SJooyung Han RegBlockInt32<4, 4> bottom_left; 368*5f39d1b3SJooyung Han bottom_left.buf.reg[0] = src.buf.reg[1]; 369*5f39d1b3SJooyung Han bottom_left.buf.reg[1] = src.buf.reg[3]; 370*5f39d1b3SJooyung Han bottom_left.buf.reg[2] = src.buf.reg[5]; 371*5f39d1b3SJooyung Han bottom_left.buf.reg[3] = src.buf.reg[7]; 372*5f39d1b3SJooyung Han const auto transpose_bottom_left = Transpose(bottom_left); 373*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 374*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + 4 + i, col), 375*5f39d1b3SJooyung Han transpose_bottom_left.buf.reg[i]); 376*5f39d1b3SJooyung Han } 377*5f39d1b3SJooyung Han RegBlockInt32<4, 4> top_right; 378*5f39d1b3SJooyung Han top_right.buf.reg[0] = src.buf.reg[8]; 379*5f39d1b3SJooyung Han top_right.buf.reg[1] = src.buf.reg[10]; 380*5f39d1b3SJooyung Han top_right.buf.reg[2] = src.buf.reg[12]; 381*5f39d1b3SJooyung Han top_right.buf.reg[3] = src.buf.reg[14]; 382*5f39d1b3SJooyung Han const auto transpose_top_right = Transpose(top_right); 383*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 384*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + i, col + 4), 385*5f39d1b3SJooyung Han transpose_top_right.buf.reg[i]); 386*5f39d1b3SJooyung Han } 387*5f39d1b3SJooyung Han RegBlockInt32<4, 4> bottom_right; 388*5f39d1b3SJooyung Han bottom_right.buf.reg[0] = src.buf.reg[9]; 389*5f39d1b3SJooyung Han bottom_right.buf.reg[1] = src.buf.reg[11]; 390*5f39d1b3SJooyung Han bottom_right.buf.reg[2] = src.buf.reg[13]; 391*5f39d1b3SJooyung Han bottom_right.buf.reg[3] = src.buf.reg[15]; 392*5f39d1b3SJooyung Han const auto transpose_bottom_right = Transpose(bottom_right); 393*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 394*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row + 4 + i, col + 4), 395*5f39d1b3SJooyung Han transpose_bottom_right.buf.reg[i]); 396*5f39d1b3SJooyung Han } 397*5f39d1b3SJooyung Han } 398*5f39d1b3SJooyung Han } 399*5f39d1b3SJooyung Han }; 400*5f39d1b3SJooyung Han 401*5f39d1b3SJooyung Han template <typename DstType> 402*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> { 403*5f39d1b3SJooyung Han static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row, 404*5f39d1b3SJooyung Han int col) { 405*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 406*5f39d1b3SJooyung Han for (int i = 0; i < 8; i++) { 407*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]); 408*5f39d1b3SJooyung Han } 409*5f39d1b3SJooyung Han } else { 410*5f39d1b3SJooyung Han // top-left 4x4 411*5f39d1b3SJooyung Han __m128i t0 = _mm_unpacklo_epi16(src.buf.reg[0], src.buf.reg[1]); 412*5f39d1b3SJooyung Han __m128i t1 = _mm_unpacklo_epi16(src.buf.reg[2], src.buf.reg[3]); 413*5f39d1b3SJooyung Han __m128i u0 = _mm_unpacklo_epi32(t0, t1); 414*5f39d1b3SJooyung Han __m128i u1 = _mm_unpackhi_epi32(t0, t1); 415*5f39d1b3SJooyung Han // top-right 4x4 416*5f39d1b3SJooyung Han __m128i t2 = _mm_unpacklo_epi16(src.buf.reg[4], src.buf.reg[5]); 417*5f39d1b3SJooyung Han __m128i t3 = _mm_unpacklo_epi16(src.buf.reg[6], src.buf.reg[7]); 418*5f39d1b3SJooyung Han __m128i u2 = _mm_unpacklo_epi32(t2, t3); 419*5f39d1b3SJooyung Han __m128i u3 = _mm_unpackhi_epi32(t2, t3); 420*5f39d1b3SJooyung Han // bottom-left 4x4 421*5f39d1b3SJooyung Han __m128i t4 = _mm_unpackhi_epi16(src.buf.reg[0], src.buf.reg[1]); 422*5f39d1b3SJooyung Han __m128i t5 = _mm_unpackhi_epi16(src.buf.reg[2], src.buf.reg[3]); 423*5f39d1b3SJooyung Han __m128i u4 = _mm_unpacklo_epi32(t4, t5); 424*5f39d1b3SJooyung Han __m128i u5 = _mm_unpackhi_epi32(t4, t5); 425*5f39d1b3SJooyung Han // bottom-right 4x4 426*5f39d1b3SJooyung Han __m128i t6 = _mm_unpackhi_epi16(src.buf.reg[4], src.buf.reg[5]); 427*5f39d1b3SJooyung Han __m128i t7 = _mm_unpackhi_epi16(src.buf.reg[6], src.buf.reg[7]); 428*5f39d1b3SJooyung Han __m128i u6 = _mm_unpacklo_epi32(t6, t7); 429*5f39d1b3SJooyung Han __m128i u7 = _mm_unpackhi_epi32(t6, t7); 430*5f39d1b3SJooyung Han 431*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 0, col), _mm_unpacklo_epi64(u0, u2)); 432*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 1, col), _mm_unpackhi_epi64(u0, u2)); 433*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 2, col), _mm_unpacklo_epi64(u1, u3)); 434*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 3, col), _mm_unpackhi_epi64(u1, u3)); 435*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 4, col), _mm_unpacklo_epi64(u4, u6)); 436*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 5, col), _mm_unpackhi_epi64(u4, u6)); 437*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 6, col), _mm_unpacklo_epi64(u5, u7)); 438*5f39d1b3SJooyung Han StoreInt16x8(dst->data(row + 7, col), _mm_unpackhi_epi64(u5, u7)); 439*5f39d1b3SJooyung Han } 440*5f39d1b3SJooyung Han } 441*5f39d1b3SJooyung Han }; 442*5f39d1b3SJooyung Han 443*5f39d1b3SJooyung Han template <typename DstType> 444*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> { 445*5f39d1b3SJooyung Han static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row, 446*5f39d1b3SJooyung Han int col) { 447*5f39d1b3SJooyung Han if (DstType::kOrder == MapOrder::ColMajor) { 448*5f39d1b3SJooyung Han *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]); 449*5f39d1b3SJooyung Han *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]); 450*5f39d1b3SJooyung Han *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]); 451*5f39d1b3SJooyung Han *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]); 452*5f39d1b3SJooyung Han } else { 453*5f39d1b3SJooyung Han StoreInt32x4(dst->data(row, col), src.buf.reg[0]); 454*5f39d1b3SJooyung Han } 455*5f39d1b3SJooyung Han } 456*5f39d1b3SJooyung Han }; 457*5f39d1b3SJooyung Han 458*5f39d1b3SJooyung Han template <typename DstType> 459*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> { 460*5f39d1b3SJooyung Han static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row, 461*5f39d1b3SJooyung Han int col) { 462*5f39d1b3SJooyung Han const std::uint32_t src_reg = src.buf.reg[0]; 463*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 464*5f39d1b3SJooyung Han *dst->data(row + i, col) = (src_reg >> (8 * i)); 465*5f39d1b3SJooyung Han } 466*5f39d1b3SJooyung Han } 467*5f39d1b3SJooyung Han }; 468*5f39d1b3SJooyung Han 469*5f39d1b3SJooyung Han template <typename DstType> 470*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> { 471*5f39d1b3SJooyung Han static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row, 472*5f39d1b3SJooyung Han int col) { 473*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 474*5f39d1b3SJooyung Han *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i)); 475*5f39d1b3SJooyung Han } 476*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 477*5f39d1b3SJooyung Han *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i)); 478*5f39d1b3SJooyung Han } 479*5f39d1b3SJooyung Han } 480*5f39d1b3SJooyung Han }; 481*5f39d1b3SJooyung Han 482*5f39d1b3SJooyung Han template <typename DstType> 483*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> { 484*5f39d1b3SJooyung Han static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row, 485*5f39d1b3SJooyung Han int col) { 486*5f39d1b3SJooyung Han for (int i = 0; i < 4; i++) { 487*5f39d1b3SJooyung Han *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i)); 488*5f39d1b3SJooyung Han } 489*5f39d1b3SJooyung Han } 490*5f39d1b3SJooyung Han }; 491*5f39d1b3SJooyung Han 492*5f39d1b3SJooyung Han template <typename DstType> 493*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> { 494*5f39d1b3SJooyung Han static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row, 495*5f39d1b3SJooyung Han int col) { 496*5f39d1b3SJooyung Han std::uint8_t buf[16]; 497*5f39d1b3SJooyung Han StoreUint8x16(buf, src.buf.reg[0]); 498*5f39d1b3SJooyung Han for (int c = 0; c < 4; c++) { 499*5f39d1b3SJooyung Han for (int r = 0; r < 4; r++) { 500*5f39d1b3SJooyung Han *dst->data(row + r, col + c) = buf[r + 4 * c]; 501*5f39d1b3SJooyung Han } 502*5f39d1b3SJooyung Han } 503*5f39d1b3SJooyung Han } 504*5f39d1b3SJooyung Han }; 505*5f39d1b3SJooyung Han 506*5f39d1b3SJooyung Han template <typename DstType> 507*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { 508*5f39d1b3SJooyung Han static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row, 509*5f39d1b3SJooyung Han int col) { 510*5f39d1b3SJooyung Han std::uint8_t buf[32]; 511*5f39d1b3SJooyung Han StoreUint8x16(buf, src.buf.reg[0]); 512*5f39d1b3SJooyung Han StoreUint8x16(buf + 16, src.buf.reg[1]); 513*5f39d1b3SJooyung Han for (int c = 0; c < 4; c++) { 514*5f39d1b3SJooyung Han for (int r = 0; r < 8; r++) { 515*5f39d1b3SJooyung Han *dst->data(row + r, col + c) = buf[r + 8 * c]; 516*5f39d1b3SJooyung Han } 517*5f39d1b3SJooyung Han } 518*5f39d1b3SJooyung Han } 519*5f39d1b3SJooyung Han }; 520*5f39d1b3SJooyung Han 521*5f39d1b3SJooyung Han template <typename DstType> 522*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { 523*5f39d1b3SJooyung Han static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row, 524*5f39d1b3SJooyung Han int col) { 525*5f39d1b3SJooyung Han std::uint8_t buf[64]; 526*5f39d1b3SJooyung Han StoreUint8x16(buf, src.buf.reg[0]); 527*5f39d1b3SJooyung Han StoreUint8x16(buf + 16, src.buf.reg[1]); 528*5f39d1b3SJooyung Han StoreUint8x16(buf + 32, src.buf.reg[2]); 529*5f39d1b3SJooyung Han StoreUint8x16(buf + 48, src.buf.reg[3]); 530*5f39d1b3SJooyung Han for (int c = 0; c < 8; c++) { 531*5f39d1b3SJooyung Han for (int r = 0; r < 8; r++) { 532*5f39d1b3SJooyung Han *dst->data(row + r, col + c) = buf[r + 8 * c]; 533*5f39d1b3SJooyung Han } 534*5f39d1b3SJooyung Han } 535*5f39d1b3SJooyung Han } 536*5f39d1b3SJooyung Han }; 537*5f39d1b3SJooyung Han 538*5f39d1b3SJooyung Han // Specialization for MatrixMap, for performance. 539*5f39d1b3SJooyung Han template <typename tScalar, MapOrder tOrder> 540*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, MatrixMap<tScalar, tOrder>> { 541*5f39d1b3SJooyung Han static void Run(const RegBlockUint8<8, 8>& src, 542*5f39d1b3SJooyung Han MatrixMap<tScalar, tOrder>* dst, int row, int col) { 543*5f39d1b3SJooyung Han std::uint8_t buf[64]; 544*5f39d1b3SJooyung Han StoreUint8x16(buf, src.buf.reg[0]); 545*5f39d1b3SJooyung Han StoreUint8x16(buf + 16, src.buf.reg[1]); 546*5f39d1b3SJooyung Han StoreUint8x16(buf + 32, src.buf.reg[2]); 547*5f39d1b3SJooyung Han StoreUint8x16(buf + 48, src.buf.reg[3]); 548*5f39d1b3SJooyung Han // Make a local copy so that the compiler can prove that data_ does not 549*5f39d1b3SJooyung Han // alias &data_ or &stride_. 550*5f39d1b3SJooyung Han MatrixMap<tScalar, tOrder> local = *dst; 551*5f39d1b3SJooyung Han for (int c = 0; c < 8; c++) { 552*5f39d1b3SJooyung Han for (int r = 0; r < 8; r++) { 553*5f39d1b3SJooyung Han *local.data(row + r, col + c) = buf[r + 8 * c]; 554*5f39d1b3SJooyung Han } 555*5f39d1b3SJooyung Han } 556*5f39d1b3SJooyung Han } 557*5f39d1b3SJooyung Han }; 558*5f39d1b3SJooyung Han 559*5f39d1b3SJooyung Han } // namespace gemmlowp 560*5f39d1b3SJooyung Han 561*5f39d1b3SJooyung Han #endif // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ 562