xref: /aosp_15_r20/external/gemmlowp/internal/output_sse.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2015 Google Inc. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // output_sse.h: optimized SSE4.2 specializations of the templates in output.h.
16*5f39d1b3SJooyung Han 
17*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
18*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
19*5f39d1b3SJooyung Han 
20*5f39d1b3SJooyung Han #include "output.h"
21*5f39d1b3SJooyung Han 
22*5f39d1b3SJooyung Han #include <smmintrin.h>
23*5f39d1b3SJooyung Han 
24*5f39d1b3SJooyung Han namespace gemmlowp {
25*5f39d1b3SJooyung Han 
26*5f39d1b3SJooyung Han template <>
27*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
28*5f39d1b3SJooyung Han                                  RegBufferInt32<4>> {
29*5f39d1b3SJooyung Han   typedef RegBufferInt32<4> InputType;
30*5f39d1b3SJooyung Han   typedef RegBufferUint8<4> OutputType;
31*5f39d1b3SJooyung Han 
32*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToUint8 OutputStage;
33*5f39d1b3SJooyung Han 
34*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
35*5f39d1b3SJooyung Han 
36*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
37*5f39d1b3SJooyung Han     OutputType output;
38*5f39d1b3SJooyung Han     __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
39*5f39d1b3SJooyung Han     __m128i res_8 = _mm_packus_epi16(res_16, res_16);
40*5f39d1b3SJooyung Han     output.reg[0] = _mm_cvtsi128_si32(res_8);
41*5f39d1b3SJooyung Han     return output;
42*5f39d1b3SJooyung Han   }
43*5f39d1b3SJooyung Han };
44*5f39d1b3SJooyung Han 
45*5f39d1b3SJooyung Han template <>
46*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
47*5f39d1b3SJooyung Han                                  RegBufferInt32<8>> {
48*5f39d1b3SJooyung Han   typedef RegBufferInt32<8> InputType;
49*5f39d1b3SJooyung Han   typedef RegBufferUint8<8> OutputType;
50*5f39d1b3SJooyung Han 
51*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToUint8 OutputStage;
52*5f39d1b3SJooyung Han 
53*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
54*5f39d1b3SJooyung Han 
55*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
56*5f39d1b3SJooyung Han     OutputType output;
57*5f39d1b3SJooyung Han     __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]);
58*5f39d1b3SJooyung Han     __m128i res_8 = _mm_packus_epi16(res_16, res_16);
59*5f39d1b3SJooyung Han     output.reg[0] = _mm_extract_epi32(res_8, 0);
60*5f39d1b3SJooyung Han     output.reg[1] = _mm_extract_epi32(res_8, 1);
61*5f39d1b3SJooyung Han     return output;
62*5f39d1b3SJooyung Han   }
63*5f39d1b3SJooyung Han };
64*5f39d1b3SJooyung Han 
65*5f39d1b3SJooyung Han template <>
66*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
67*5f39d1b3SJooyung Han                                  RegBufferInt32<16>> {
68*5f39d1b3SJooyung Han   typedef RegBufferInt32<16> InputType;
69*5f39d1b3SJooyung Han   typedef RegBufferUint8<16> OutputType;
70*5f39d1b3SJooyung Han 
71*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToUint8 OutputStage;
72*5f39d1b3SJooyung Han 
73*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
74*5f39d1b3SJooyung Han 
75*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
76*5f39d1b3SJooyung Han     OutputType output;
77*5f39d1b3SJooyung Han     __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
78*5f39d1b3SJooyung Han     __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
79*5f39d1b3SJooyung Han     output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
80*5f39d1b3SJooyung Han     return output;
81*5f39d1b3SJooyung Han   }
82*5f39d1b3SJooyung Han };
83*5f39d1b3SJooyung Han 
84*5f39d1b3SJooyung Han template <>
85*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
86*5f39d1b3SJooyung Han                                  RegBufferInt32<32>> {
87*5f39d1b3SJooyung Han   typedef RegBufferInt32<32> InputType;
88*5f39d1b3SJooyung Han   typedef RegBufferUint8<32> OutputType;
89*5f39d1b3SJooyung Han 
90*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToUint8 OutputStage;
91*5f39d1b3SJooyung Han 
92*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
93*5f39d1b3SJooyung Han 
94*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
95*5f39d1b3SJooyung Han     OutputType output;
96*5f39d1b3SJooyung Han     __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
97*5f39d1b3SJooyung Han     __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
98*5f39d1b3SJooyung Han     output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
99*5f39d1b3SJooyung Han     __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]);
100*5f39d1b3SJooyung Han     __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]);
101*5f39d1b3SJooyung Han     output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3);
102*5f39d1b3SJooyung Han     return output;
103*5f39d1b3SJooyung Han   }
104*5f39d1b3SJooyung Han };
105*5f39d1b3SJooyung Han 
106*5f39d1b3SJooyung Han template <>
107*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
108*5f39d1b3SJooyung Han                                  RegBufferInt32<4>> {
109*5f39d1b3SJooyung Han   typedef RegBufferInt32<4> InputType;
110*5f39d1b3SJooyung Han   typedef RegBufferInt16<4> OutputType;
111*5f39d1b3SJooyung Han 
112*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToInt16 OutputStage;
113*5f39d1b3SJooyung Han 
114*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
115*5f39d1b3SJooyung Han 
116*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
117*5f39d1b3SJooyung Han     OutputType output;
118*5f39d1b3SJooyung Han     __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
119*5f39d1b3SJooyung Han     output.reg[0] = _mm_extract_epi16(res_16, 0);
120*5f39d1b3SJooyung Han     output.reg[1] = _mm_extract_epi16(res_16, 1);
121*5f39d1b3SJooyung Han     output.reg[2] = _mm_extract_epi16(res_16, 2);
122*5f39d1b3SJooyung Han     output.reg[3] = _mm_extract_epi16(res_16, 3);
123*5f39d1b3SJooyung Han     return output;
124*5f39d1b3SJooyung Han   }
125*5f39d1b3SJooyung Han };
126*5f39d1b3SJooyung Han 
127*5f39d1b3SJooyung Han template <>
128*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
129*5f39d1b3SJooyung Han                                  RegBufferInt32<8>> {
130*5f39d1b3SJooyung Han   typedef RegBufferInt32<8> InputType;
131*5f39d1b3SJooyung Han   typedef RegBufferInt16<8> OutputType;
132*5f39d1b3SJooyung Han 
133*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToInt16 OutputStage;
134*5f39d1b3SJooyung Han 
135*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
136*5f39d1b3SJooyung Han 
137*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
138*5f39d1b3SJooyung Han     OutputType output;
139*5f39d1b3SJooyung Han     output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
140*5f39d1b3SJooyung Han     return output;
141*5f39d1b3SJooyung Han   }
142*5f39d1b3SJooyung Han };
143*5f39d1b3SJooyung Han 
144*5f39d1b3SJooyung Han template <>
145*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
146*5f39d1b3SJooyung Han                                  RegBufferInt32<16>> {
147*5f39d1b3SJooyung Han   typedef RegBufferInt32<16> InputType;
148*5f39d1b3SJooyung Han   typedef RegBufferInt16<16> OutputType;
149*5f39d1b3SJooyung Han 
150*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToInt16 OutputStage;
151*5f39d1b3SJooyung Han 
152*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
153*5f39d1b3SJooyung Han 
154*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
155*5f39d1b3SJooyung Han     OutputType output;
156*5f39d1b3SJooyung Han     output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
157*5f39d1b3SJooyung Han     output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
158*5f39d1b3SJooyung Han     return output;
159*5f39d1b3SJooyung Han   }
160*5f39d1b3SJooyung Han };
161*5f39d1b3SJooyung Han 
162*5f39d1b3SJooyung Han template <>
163*5f39d1b3SJooyung Han struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
164*5f39d1b3SJooyung Han                                  RegBufferInt32<32>> {
165*5f39d1b3SJooyung Han   typedef RegBufferInt32<32> InputType;
166*5f39d1b3SJooyung Han   typedef RegBufferInt16<32> OutputType;
167*5f39d1b3SJooyung Han 
168*5f39d1b3SJooyung Han   typedef OutputStageSaturatingCastToInt16 OutputStage;
169*5f39d1b3SJooyung Han 
170*5f39d1b3SJooyung Han   OutputStageEvalBufferImpl(const OutputStage&) {}
171*5f39d1b3SJooyung Han 
172*5f39d1b3SJooyung Han   OutputType Eval(InputType input) const {
173*5f39d1b3SJooyung Han     OutputType output;
174*5f39d1b3SJooyung Han     output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
175*5f39d1b3SJooyung Han     output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
176*5f39d1b3SJooyung Han     output.reg[2] = _mm_packs_epi32(input.reg[4], input.reg[5]);
177*5f39d1b3SJooyung Han     output.reg[3] = _mm_packs_epi32(input.reg[6], input.reg[7]);
178*5f39d1b3SJooyung Han     return output;
179*5f39d1b3SJooyung Han   }
180*5f39d1b3SJooyung Han };
181*5f39d1b3SJooyung Han 
182*5f39d1b3SJooyung Han template <typename DstType>
183*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
184*5f39d1b3SJooyung Han   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
185*5f39d1b3SJooyung Han                   int col) {
186*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
187*5f39d1b3SJooyung Han       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
188*5f39d1b3SJooyung Han     } else {
189*5f39d1b3SJooyung Han       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
190*5f39d1b3SJooyung Han       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
191*5f39d1b3SJooyung Han       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
192*5f39d1b3SJooyung Han       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
193*5f39d1b3SJooyung Han     }
194*5f39d1b3SJooyung Han   }
195*5f39d1b3SJooyung Han };
196*5f39d1b3SJooyung Han 
197*5f39d1b3SJooyung Han template <typename DstType>
198*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
199*5f39d1b3SJooyung Han   static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
200*5f39d1b3SJooyung Han                   int col) {
201*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
202*5f39d1b3SJooyung Han       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
203*5f39d1b3SJooyung Han       StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
204*5f39d1b3SJooyung Han     } else {
205*5f39d1b3SJooyung Han       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
206*5f39d1b3SJooyung Han       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
207*5f39d1b3SJooyung Han       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
208*5f39d1b3SJooyung Han       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
209*5f39d1b3SJooyung Han       *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
210*5f39d1b3SJooyung Han       *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
211*5f39d1b3SJooyung Han       *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
212*5f39d1b3SJooyung Han       *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
213*5f39d1b3SJooyung Han     }
214*5f39d1b3SJooyung Han   }
215*5f39d1b3SJooyung Han };
216*5f39d1b3SJooyung Han 
217*5f39d1b3SJooyung Han template <typename DstType>
218*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
219*5f39d1b3SJooyung Han   static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
220*5f39d1b3SJooyung Han                   int col) {
221*5f39d1b3SJooyung Han     *dst->data(row + 0, col) = src.buf.reg[0];
222*5f39d1b3SJooyung Han     *dst->data(row + 1, col) = src.buf.reg[1];
223*5f39d1b3SJooyung Han     *dst->data(row + 2, col) = src.buf.reg[2];
224*5f39d1b3SJooyung Han     *dst->data(row + 3, col) = src.buf.reg[3];
225*5f39d1b3SJooyung Han   }
226*5f39d1b3SJooyung Han };
227*5f39d1b3SJooyung Han 
228*5f39d1b3SJooyung Han template <typename DstType>
229*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
230*5f39d1b3SJooyung Han   static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
231*5f39d1b3SJooyung Han                   int col) {
232*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
233*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
234*5f39d1b3SJooyung Han     } else {
235*5f39d1b3SJooyung Han       *dst->data(row + 0, col) = _mm_extract_epi16(src.buf.reg[0], 0);
236*5f39d1b3SJooyung Han       *dst->data(row + 1, col) = _mm_extract_epi16(src.buf.reg[0], 1);
237*5f39d1b3SJooyung Han       *dst->data(row + 2, col) = _mm_extract_epi16(src.buf.reg[0], 2);
238*5f39d1b3SJooyung Han       *dst->data(row + 3, col) = _mm_extract_epi16(src.buf.reg[0], 3);
239*5f39d1b3SJooyung Han       *dst->data(row + 4, col) = _mm_extract_epi16(src.buf.reg[0], 4);
240*5f39d1b3SJooyung Han       *dst->data(row + 5, col) = _mm_extract_epi16(src.buf.reg[0], 5);
241*5f39d1b3SJooyung Han       *dst->data(row + 6, col) = _mm_extract_epi16(src.buf.reg[0], 6);
242*5f39d1b3SJooyung Han       *dst->data(row + 7, col) = _mm_extract_epi16(src.buf.reg[0], 7);
243*5f39d1b3SJooyung Han     }
244*5f39d1b3SJooyung Han   }
245*5f39d1b3SJooyung Han };
246*5f39d1b3SJooyung Han 
247*5f39d1b3SJooyung Han inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
248*5f39d1b3SJooyung Han   __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]);
249*5f39d1b3SJooyung Han   __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]);
250*5f39d1b3SJooyung Han   __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]);
251*5f39d1b3SJooyung Han   __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]);
252*5f39d1b3SJooyung Han 
253*5f39d1b3SJooyung Han   RegBlockInt32<4, 4> result;
254*5f39d1b3SJooyung Han   result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1);
255*5f39d1b3SJooyung Han   result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1);
256*5f39d1b3SJooyung Han   result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3);
257*5f39d1b3SJooyung Han   result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3);
258*5f39d1b3SJooyung Han   return result;
259*5f39d1b3SJooyung Han }
260*5f39d1b3SJooyung Han 
261*5f39d1b3SJooyung Han template <typename DstType>
262*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
263*5f39d1b3SJooyung Han   static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
264*5f39d1b3SJooyung Han                   int col) {
265*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
266*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
267*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
268*5f39d1b3SJooyung Han       }
269*5f39d1b3SJooyung Han     } else {
270*5f39d1b3SJooyung Han       const auto transpose = Transpose(src);
271*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
272*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
273*5f39d1b3SJooyung Han       }
274*5f39d1b3SJooyung Han     }
275*5f39d1b3SJooyung Han   }
276*5f39d1b3SJooyung Han };
277*5f39d1b3SJooyung Han 
278*5f39d1b3SJooyung Han template <typename DstType>
279*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
280*5f39d1b3SJooyung Han   static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
281*5f39d1b3SJooyung Han                   int col) {
282*5f39d1b3SJooyung Han     std::int16_t buf[16];
283*5f39d1b3SJooyung Han     StoreInt16x8(buf + 0, src.buf.reg[0]);
284*5f39d1b3SJooyung Han     StoreInt16x8(buf + 8, src.buf.reg[1]);
285*5f39d1b3SJooyung Han     for (int i = 0; i < 4; i++) {
286*5f39d1b3SJooyung Han       for (int j = 0; j < 4; j++) {
287*5f39d1b3SJooyung Han         *dst->data(row + i, col + j) = buf[i + 4 * j];
288*5f39d1b3SJooyung Han       }
289*5f39d1b3SJooyung Han     }
290*5f39d1b3SJooyung Han   }
291*5f39d1b3SJooyung Han };
292*5f39d1b3SJooyung Han 
293*5f39d1b3SJooyung Han template <typename DstType>
294*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
295*5f39d1b3SJooyung Han   static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
296*5f39d1b3SJooyung Han                   int col) {
297*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
298*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
299*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
300*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
301*5f39d1b3SJooyung Han       }
302*5f39d1b3SJooyung Han     } else {
303*5f39d1b3SJooyung Han       RegBlockInt32<4, 4> top;
304*5f39d1b3SJooyung Han       top.buf.reg[0] = src.buf.reg[0];
305*5f39d1b3SJooyung Han       top.buf.reg[1] = src.buf.reg[2];
306*5f39d1b3SJooyung Han       top.buf.reg[2] = src.buf.reg[4];
307*5f39d1b3SJooyung Han       top.buf.reg[3] = src.buf.reg[6];
308*5f39d1b3SJooyung Han       const auto transpose_top = Transpose(top);
309*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
310*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
311*5f39d1b3SJooyung Han       }
312*5f39d1b3SJooyung Han       RegBlockInt32<4, 4> bottom;
313*5f39d1b3SJooyung Han       bottom.buf.reg[0] = src.buf.reg[1];
314*5f39d1b3SJooyung Han       bottom.buf.reg[1] = src.buf.reg[3];
315*5f39d1b3SJooyung Han       bottom.buf.reg[2] = src.buf.reg[5];
316*5f39d1b3SJooyung Han       bottom.buf.reg[3] = src.buf.reg[7];
317*5f39d1b3SJooyung Han       const auto transpose_bottom = Transpose(bottom);
318*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
319*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
320*5f39d1b3SJooyung Han       }
321*5f39d1b3SJooyung Han     }
322*5f39d1b3SJooyung Han   }
323*5f39d1b3SJooyung Han };
324*5f39d1b3SJooyung Han 
325*5f39d1b3SJooyung Han template <typename DstType>
326*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
327*5f39d1b3SJooyung Han   static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
328*5f39d1b3SJooyung Han                   int col) {
329*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
330*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
331*5f39d1b3SJooyung Han         StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
332*5f39d1b3SJooyung Han       }
333*5f39d1b3SJooyung Han     } else {
334*5f39d1b3SJooyung Han       std::int16_t buf[32];
335*5f39d1b3SJooyung Han       StoreInt16x8(buf + 0, src.buf.reg[0]);
336*5f39d1b3SJooyung Han       StoreInt16x8(buf + 8, src.buf.reg[1]);
337*5f39d1b3SJooyung Han       StoreInt16x8(buf + 16, src.buf.reg[2]);
338*5f39d1b3SJooyung Han       StoreInt16x8(buf + 24, src.buf.reg[3]);
339*5f39d1b3SJooyung Han       for (int i = 0; i < 8; i++) {
340*5f39d1b3SJooyung Han         for (int j = 0; j < 4; j++) {
341*5f39d1b3SJooyung Han           *dst->data(row + i, col + j) = buf[i + 8 * j];
342*5f39d1b3SJooyung Han         }
343*5f39d1b3SJooyung Han       }
344*5f39d1b3SJooyung Han     }
345*5f39d1b3SJooyung Han   }
346*5f39d1b3SJooyung Han };
347*5f39d1b3SJooyung Han 
348*5f39d1b3SJooyung Han template <typename DstType>
349*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
350*5f39d1b3SJooyung Han   static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
351*5f39d1b3SJooyung Han                   int col) {
352*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
353*5f39d1b3SJooyung Han       for (int i = 0; i < 8; i++) {
354*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
355*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
356*5f39d1b3SJooyung Han       }
357*5f39d1b3SJooyung Han     } else {
358*5f39d1b3SJooyung Han       RegBlockInt32<4, 4> top_left;
359*5f39d1b3SJooyung Han       top_left.buf.reg[0] = src.buf.reg[0];
360*5f39d1b3SJooyung Han       top_left.buf.reg[1] = src.buf.reg[2];
361*5f39d1b3SJooyung Han       top_left.buf.reg[2] = src.buf.reg[4];
362*5f39d1b3SJooyung Han       top_left.buf.reg[3] = src.buf.reg[6];
363*5f39d1b3SJooyung Han       const auto transpose_top_left = Transpose(top_left);
364*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
365*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
366*5f39d1b3SJooyung Han       }
367*5f39d1b3SJooyung Han       RegBlockInt32<4, 4> bottom_left;
368*5f39d1b3SJooyung Han       bottom_left.buf.reg[0] = src.buf.reg[1];
369*5f39d1b3SJooyung Han       bottom_left.buf.reg[1] = src.buf.reg[3];
370*5f39d1b3SJooyung Han       bottom_left.buf.reg[2] = src.buf.reg[5];
371*5f39d1b3SJooyung Han       bottom_left.buf.reg[3] = src.buf.reg[7];
372*5f39d1b3SJooyung Han       const auto transpose_bottom_left = Transpose(bottom_left);
373*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
374*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + 4 + i, col),
375*5f39d1b3SJooyung Han                      transpose_bottom_left.buf.reg[i]);
376*5f39d1b3SJooyung Han       }
377*5f39d1b3SJooyung Han       RegBlockInt32<4, 4> top_right;
378*5f39d1b3SJooyung Han       top_right.buf.reg[0] = src.buf.reg[8];
379*5f39d1b3SJooyung Han       top_right.buf.reg[1] = src.buf.reg[10];
380*5f39d1b3SJooyung Han       top_right.buf.reg[2] = src.buf.reg[12];
381*5f39d1b3SJooyung Han       top_right.buf.reg[3] = src.buf.reg[14];
382*5f39d1b3SJooyung Han       const auto transpose_top_right = Transpose(top_right);
383*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
384*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + i, col + 4),
385*5f39d1b3SJooyung Han                      transpose_top_right.buf.reg[i]);
386*5f39d1b3SJooyung Han       }
387*5f39d1b3SJooyung Han       RegBlockInt32<4, 4> bottom_right;
388*5f39d1b3SJooyung Han       bottom_right.buf.reg[0] = src.buf.reg[9];
389*5f39d1b3SJooyung Han       bottom_right.buf.reg[1] = src.buf.reg[11];
390*5f39d1b3SJooyung Han       bottom_right.buf.reg[2] = src.buf.reg[13];
391*5f39d1b3SJooyung Han       bottom_right.buf.reg[3] = src.buf.reg[15];
392*5f39d1b3SJooyung Han       const auto transpose_bottom_right = Transpose(bottom_right);
393*5f39d1b3SJooyung Han       for (int i = 0; i < 4; i++) {
394*5f39d1b3SJooyung Han         StoreInt32x4(dst->data(row + 4 + i, col + 4),
395*5f39d1b3SJooyung Han                      transpose_bottom_right.buf.reg[i]);
396*5f39d1b3SJooyung Han       }
397*5f39d1b3SJooyung Han     }
398*5f39d1b3SJooyung Han   }
399*5f39d1b3SJooyung Han };
400*5f39d1b3SJooyung Han 
401*5f39d1b3SJooyung Han template <typename DstType>
402*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
403*5f39d1b3SJooyung Han   static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
404*5f39d1b3SJooyung Han                   int col) {
405*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
406*5f39d1b3SJooyung Han       for (int i = 0; i < 8; i++) {
407*5f39d1b3SJooyung Han         StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
408*5f39d1b3SJooyung Han       }
409*5f39d1b3SJooyung Han     } else {
410*5f39d1b3SJooyung Han       // top-left 4x4
411*5f39d1b3SJooyung Han       __m128i t0 = _mm_unpacklo_epi16(src.buf.reg[0], src.buf.reg[1]);
412*5f39d1b3SJooyung Han       __m128i t1 = _mm_unpacklo_epi16(src.buf.reg[2], src.buf.reg[3]);
413*5f39d1b3SJooyung Han       __m128i u0 = _mm_unpacklo_epi32(t0, t1);
414*5f39d1b3SJooyung Han       __m128i u1 = _mm_unpackhi_epi32(t0, t1);
415*5f39d1b3SJooyung Han       // top-right 4x4
416*5f39d1b3SJooyung Han       __m128i t2 = _mm_unpacklo_epi16(src.buf.reg[4], src.buf.reg[5]);
417*5f39d1b3SJooyung Han       __m128i t3 = _mm_unpacklo_epi16(src.buf.reg[6], src.buf.reg[7]);
418*5f39d1b3SJooyung Han       __m128i u2 = _mm_unpacklo_epi32(t2, t3);
419*5f39d1b3SJooyung Han       __m128i u3 = _mm_unpackhi_epi32(t2, t3);
420*5f39d1b3SJooyung Han       // bottom-left 4x4
421*5f39d1b3SJooyung Han       __m128i t4 = _mm_unpackhi_epi16(src.buf.reg[0], src.buf.reg[1]);
422*5f39d1b3SJooyung Han       __m128i t5 = _mm_unpackhi_epi16(src.buf.reg[2], src.buf.reg[3]);
423*5f39d1b3SJooyung Han       __m128i u4 = _mm_unpacklo_epi32(t4, t5);
424*5f39d1b3SJooyung Han       __m128i u5 = _mm_unpackhi_epi32(t4, t5);
425*5f39d1b3SJooyung Han       // bottom-right 4x4
426*5f39d1b3SJooyung Han       __m128i t6 = _mm_unpackhi_epi16(src.buf.reg[4], src.buf.reg[5]);
427*5f39d1b3SJooyung Han       __m128i t7 = _mm_unpackhi_epi16(src.buf.reg[6], src.buf.reg[7]);
428*5f39d1b3SJooyung Han       __m128i u6 = _mm_unpacklo_epi32(t6, t7);
429*5f39d1b3SJooyung Han       __m128i u7 = _mm_unpackhi_epi32(t6, t7);
430*5f39d1b3SJooyung Han 
431*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 0, col), _mm_unpacklo_epi64(u0, u2));
432*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 1, col), _mm_unpackhi_epi64(u0, u2));
433*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 2, col), _mm_unpacklo_epi64(u1, u3));
434*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 3, col), _mm_unpackhi_epi64(u1, u3));
435*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 4, col), _mm_unpacklo_epi64(u4, u6));
436*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 5, col), _mm_unpackhi_epi64(u4, u6));
437*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 6, col), _mm_unpacklo_epi64(u5, u7));
438*5f39d1b3SJooyung Han       StoreInt16x8(dst->data(row + 7, col), _mm_unpackhi_epi64(u5, u7));
439*5f39d1b3SJooyung Han     }
440*5f39d1b3SJooyung Han   }
441*5f39d1b3SJooyung Han };
442*5f39d1b3SJooyung Han 
443*5f39d1b3SJooyung Han template <typename DstType>
444*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
445*5f39d1b3SJooyung Han   static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
446*5f39d1b3SJooyung Han                   int col) {
447*5f39d1b3SJooyung Han     if (DstType::kOrder == MapOrder::ColMajor) {
448*5f39d1b3SJooyung Han       *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
449*5f39d1b3SJooyung Han       *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
450*5f39d1b3SJooyung Han       *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
451*5f39d1b3SJooyung Han       *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
452*5f39d1b3SJooyung Han     } else {
453*5f39d1b3SJooyung Han       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
454*5f39d1b3SJooyung Han     }
455*5f39d1b3SJooyung Han   }
456*5f39d1b3SJooyung Han };
457*5f39d1b3SJooyung Han 
458*5f39d1b3SJooyung Han template <typename DstType>
459*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
460*5f39d1b3SJooyung Han   static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
461*5f39d1b3SJooyung Han                   int col) {
462*5f39d1b3SJooyung Han     const std::uint32_t src_reg = src.buf.reg[0];
463*5f39d1b3SJooyung Han     for (int i = 0; i < 4; i++) {
464*5f39d1b3SJooyung Han       *dst->data(row + i, col) = (src_reg >> (8 * i));
465*5f39d1b3SJooyung Han     }
466*5f39d1b3SJooyung Han   }
467*5f39d1b3SJooyung Han };
468*5f39d1b3SJooyung Han 
469*5f39d1b3SJooyung Han template <typename DstType>
470*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
471*5f39d1b3SJooyung Han   static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
472*5f39d1b3SJooyung Han                   int col) {
473*5f39d1b3SJooyung Han     for (int i = 0; i < 4; i++) {
474*5f39d1b3SJooyung Han       *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
475*5f39d1b3SJooyung Han     }
476*5f39d1b3SJooyung Han     for (int i = 0; i < 4; i++) {
477*5f39d1b3SJooyung Han       *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
478*5f39d1b3SJooyung Han     }
479*5f39d1b3SJooyung Han   }
480*5f39d1b3SJooyung Han };
481*5f39d1b3SJooyung Han 
482*5f39d1b3SJooyung Han template <typename DstType>
483*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
484*5f39d1b3SJooyung Han   static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
485*5f39d1b3SJooyung Han                   int col) {
486*5f39d1b3SJooyung Han     for (int i = 0; i < 4; i++) {
487*5f39d1b3SJooyung Han       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
488*5f39d1b3SJooyung Han     }
489*5f39d1b3SJooyung Han   }
490*5f39d1b3SJooyung Han };
491*5f39d1b3SJooyung Han 
492*5f39d1b3SJooyung Han template <typename DstType>
493*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
494*5f39d1b3SJooyung Han   static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
495*5f39d1b3SJooyung Han                   int col) {
496*5f39d1b3SJooyung Han     std::uint8_t buf[16];
497*5f39d1b3SJooyung Han     StoreUint8x16(buf, src.buf.reg[0]);
498*5f39d1b3SJooyung Han     for (int c = 0; c < 4; c++) {
499*5f39d1b3SJooyung Han       for (int r = 0; r < 4; r++) {
500*5f39d1b3SJooyung Han         *dst->data(row + r, col + c) = buf[r + 4 * c];
501*5f39d1b3SJooyung Han       }
502*5f39d1b3SJooyung Han     }
503*5f39d1b3SJooyung Han   }
504*5f39d1b3SJooyung Han };
505*5f39d1b3SJooyung Han 
506*5f39d1b3SJooyung Han template <typename DstType>
507*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
508*5f39d1b3SJooyung Han   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
509*5f39d1b3SJooyung Han                   int col) {
510*5f39d1b3SJooyung Han     std::uint8_t buf[32];
511*5f39d1b3SJooyung Han     StoreUint8x16(buf, src.buf.reg[0]);
512*5f39d1b3SJooyung Han     StoreUint8x16(buf + 16, src.buf.reg[1]);
513*5f39d1b3SJooyung Han     for (int c = 0; c < 4; c++) {
514*5f39d1b3SJooyung Han       for (int r = 0; r < 8; r++) {
515*5f39d1b3SJooyung Han         *dst->data(row + r, col + c) = buf[r + 8 * c];
516*5f39d1b3SJooyung Han       }
517*5f39d1b3SJooyung Han     }
518*5f39d1b3SJooyung Han   }
519*5f39d1b3SJooyung Han };
520*5f39d1b3SJooyung Han 
521*5f39d1b3SJooyung Han template <typename DstType>
522*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
523*5f39d1b3SJooyung Han   static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
524*5f39d1b3SJooyung Han                   int col) {
525*5f39d1b3SJooyung Han     std::uint8_t buf[64];
526*5f39d1b3SJooyung Han     StoreUint8x16(buf, src.buf.reg[0]);
527*5f39d1b3SJooyung Han     StoreUint8x16(buf + 16, src.buf.reg[1]);
528*5f39d1b3SJooyung Han     StoreUint8x16(buf + 32, src.buf.reg[2]);
529*5f39d1b3SJooyung Han     StoreUint8x16(buf + 48, src.buf.reg[3]);
530*5f39d1b3SJooyung Han     for (int c = 0; c < 8; c++) {
531*5f39d1b3SJooyung Han       for (int r = 0; r < 8; r++) {
532*5f39d1b3SJooyung Han         *dst->data(row + r, col + c) = buf[r + 8 * c];
533*5f39d1b3SJooyung Han       }
534*5f39d1b3SJooyung Han     }
535*5f39d1b3SJooyung Han   }
536*5f39d1b3SJooyung Han };
537*5f39d1b3SJooyung Han 
538*5f39d1b3SJooyung Han // Specialization for MatrixMap, for performance.
539*5f39d1b3SJooyung Han template <typename tScalar, MapOrder tOrder>
540*5f39d1b3SJooyung Han struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, MatrixMap<tScalar, tOrder>> {
541*5f39d1b3SJooyung Han   static void Run(const RegBlockUint8<8, 8>& src,
542*5f39d1b3SJooyung Han                   MatrixMap<tScalar, tOrder>* dst, int row, int col) {
543*5f39d1b3SJooyung Han     std::uint8_t buf[64];
544*5f39d1b3SJooyung Han     StoreUint8x16(buf, src.buf.reg[0]);
545*5f39d1b3SJooyung Han     StoreUint8x16(buf + 16, src.buf.reg[1]);
546*5f39d1b3SJooyung Han     StoreUint8x16(buf + 32, src.buf.reg[2]);
547*5f39d1b3SJooyung Han     StoreUint8x16(buf + 48, src.buf.reg[3]);
548*5f39d1b3SJooyung Han     // Make a local copy so that the compiler can prove that data_ does not
549*5f39d1b3SJooyung Han     // alias &data_ or &stride_.
550*5f39d1b3SJooyung Han     MatrixMap<tScalar, tOrder> local = *dst;
551*5f39d1b3SJooyung Han     for (int c = 0; c < 8; c++) {
552*5f39d1b3SJooyung Han       for (int r = 0; r < 8; r++) {
553*5f39d1b3SJooyung Han         *local.data(row + r, col + c) = buf[r + 8 * c];
554*5f39d1b3SJooyung Han       }
555*5f39d1b3SJooyung Han     }
556*5f39d1b3SJooyung Han   }
557*5f39d1b3SJooyung Han };
558*5f39d1b3SJooyung Han 
559*5f39d1b3SJooyung Han }  // namespace gemmlowp
560*5f39d1b3SJooyung Han 
561*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
562