xref: /aosp_15_r20/external/gemmlowp/internal/pack_msa.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // pack_msa.h: optimized MSA specializations of the templates in pack.h.
16*5f39d1b3SJooyung Han 
17*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_PACK_MSA_H_
18*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_PACK_MSA_H_
19*5f39d1b3SJooyung Han 
20*5f39d1b3SJooyung Han #include "pack.h"
21*5f39d1b3SJooyung Han 
22*5f39d1b3SJooyung Han #include <msa.h>
23*5f39d1b3SJooyung Han 
24*5f39d1b3SJooyung Han namespace gemmlowp {
25*5f39d1b3SJooyung Han 
26*5f39d1b3SJooyung Han typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
27*5f39d1b3SJooyung Han     WidthMajorUint8SideMap;
28*5f39d1b3SJooyung Han 
29*5f39d1b3SJooyung Han template <int Cells>
30*5f39d1b3SJooyung Han using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
31*5f39d1b3SJooyung Han 
32*5f39d1b3SJooyung Han template <int Cells>
33*5f39d1b3SJooyung Han class PackingRegisterBlock<
34*5f39d1b3SJooyung Han     WidthMajorUint8SideMap,
35*5f39d1b3SJooyung Han     PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
36*5f39d1b3SJooyung Han     : public PackingRegisterBlockBase<
37*5f39d1b3SJooyung Han           WidthMajorUint8SideMap,
38*5f39d1b3SJooyung Han           PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
39*5f39d1b3SJooyung Han  public:
40*5f39d1b3SJooyung Han   typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
41*5f39d1b3SJooyung Han   typedef typename KernelSideFormat::Cell CellFormat;
42*5f39d1b3SJooyung Han   static constexpr int kCells = KernelSideFormat::kCells;
43*5f39d1b3SJooyung Han   static const int kCellWidth = CellFormat::kWidth;
44*5f39d1b3SJooyung Han   static const int kKernelWidth = CellFormat::kWidth * kCells;
45*5f39d1b3SJooyung Han   static const int kCellDepth = CellFormat::kDepth;
46*5f39d1b3SJooyung Han   static const int kCellSize = CellFormat::kSize;
47*5f39d1b3SJooyung Han 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)48*5f39d1b3SJooyung Han   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
49*5f39d1b3SJooyung Han     std::uint8_t* dst_ptr = dst->current_data();
50*5f39d1b3SJooyung Han     const std::uint8_t* const src_ptr = this->complete_src_.data();
51*5f39d1b3SJooyung Han     const int stride = this->complete_src_.stride();
52*5f39d1b3SJooyung Han     // Load source WidthMajor data
53*5f39d1b3SJooyung Han     v16i8 src_lines[4 * kCells];
54*5f39d1b3SJooyung Han     for (int i = 0; i < 4 * kCells; i++) {
55*5f39d1b3SJooyung Han       src_lines[i] = __builtin_msa_ld_b(
56*5f39d1b3SJooyung Han           const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
57*5f39d1b3SJooyung Han     }
58*5f39d1b3SJooyung Han     // Reorder the data within registers to make DepthMajor 4x2 cells
59*5f39d1b3SJooyung Han     v16i8 src_lines_intertwined_2x[2 * kCells][2];
60*5f39d1b3SJooyung Han     for (int i = 0; i < kCells; i++) {
61*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i][0] =
62*5f39d1b3SJooyung Han           __builtin_msa_ilvr_b(src_lines[4 * i + 2], src_lines[4 * i]);
63*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i][1] =
64*5f39d1b3SJooyung Han           __builtin_msa_ilvl_b(src_lines[4 * i + 2], src_lines[4 * i]);
65*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i + 1][0] =
66*5f39d1b3SJooyung Han           __builtin_msa_ilvr_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
67*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i + 1][1] =
68*5f39d1b3SJooyung Han           __builtin_msa_ilvl_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
69*5f39d1b3SJooyung Han     }
70*5f39d1b3SJooyung Han     v16i8 src_lines_intertwined_4x[2 * kCells][2];
71*5f39d1b3SJooyung Han     for (int i = 0; i < kCells; i++) {
72*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i][0] =
73*5f39d1b3SJooyung Han           __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][0],
74*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][0]);
75*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i][1] =
76*5f39d1b3SJooyung Han           __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][0],
77*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][0]);
78*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i + 1][0] =
79*5f39d1b3SJooyung Han           __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][1],
80*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][1]);
81*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i + 1][1] =
82*5f39d1b3SJooyung Han           __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][1],
83*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][1]);
84*5f39d1b3SJooyung Han     }
85*5f39d1b3SJooyung Han     // Store the resulting DepthMajor 4x2 cells in the destination packed block
86*5f39d1b3SJooyung Han     for (int outer = 0; outer < 2; outer++) {
87*5f39d1b3SJooyung Han       for (int inner = 0; inner < 2; inner++) {
88*5f39d1b3SJooyung Han         if (kCells % 2 == 0) {
89*5f39d1b3SJooyung Han           for (int cell = 0; cell < kCells; cell += 2) {
90*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvr_d(
91*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
92*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
93*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
94*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
95*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
96*5f39d1b3SJooyung Han             dst_ptr += 16;
97*5f39d1b3SJooyung Han           }
98*5f39d1b3SJooyung Han           for (int cell = 0; cell < kCells; cell += 2) {
99*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvl_d(
100*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
101*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
102*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
103*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
104*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
105*5f39d1b3SJooyung Han             dst_ptr += 16;
106*5f39d1b3SJooyung Han           }
107*5f39d1b3SJooyung Han         } else {
108*5f39d1b3SJooyung Han           // Store even number of low vector halves.
109*5f39d1b3SJooyung Han           for (int cell = 0; cell < kCells - 1; cell += 2) {
110*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvr_d(
111*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
112*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
113*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
114*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
115*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
116*5f39d1b3SJooyung Han             dst_ptr += 16;
117*5f39d1b3SJooyung Han           }
118*5f39d1b3SJooyung Han           // Store last low half and first high half.
119*5f39d1b3SJooyung Han           v2i64 tmp = reinterpret_cast<v2i64>(
120*5f39d1b3SJooyung Han               src_lines_intertwined_4x[2 * 0 + outer][inner]);
121*5f39d1b3SJooyung Han           tmp = __builtin_msa_insve_d(
122*5f39d1b3SJooyung Han               tmp, 0,
123*5f39d1b3SJooyung Han               reinterpret_cast<v2i64>(
124*5f39d1b3SJooyung Han                   src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
125*5f39d1b3SJooyung Han           __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
126*5f39d1b3SJooyung Han           dst_ptr += 16;
127*5f39d1b3SJooyung Han           // Store even number of high vector halves.
128*5f39d1b3SJooyung Han           for (int cell = 1; cell < kCells; cell += 2) {
129*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvl_d(
130*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
131*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
132*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
133*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
134*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
135*5f39d1b3SJooyung Han             dst_ptr += 16;
136*5f39d1b3SJooyung Han           }
137*5f39d1b3SJooyung Han         }
138*5f39d1b3SJooyung Han       }
139*5f39d1b3SJooyung Han     }
140*5f39d1b3SJooyung Han     // Compute sums across the depth dimension
141*5f39d1b3SJooyung Han     v8i16 sums_of_2_cells[kCells][4];
142*5f39d1b3SJooyung Han     const v16i8 zeroes = __builtin_msa_ldi_b(0);
143*5f39d1b3SJooyung Han     for (int outer = 0; outer < 2; outer++) {
144*5f39d1b3SJooyung Han       for (int inner = 0; inner < 2; inner++) {
145*5f39d1b3SJooyung Han         int i = 2 * outer + inner;
146*5f39d1b3SJooyung Han         for (int cell = 0; cell < kCells; cell++) {
147*5f39d1b3SJooyung Han           v8i16 tmp0 = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(
148*5f39d1b3SJooyung Han               zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
149*5f39d1b3SJooyung Han           v8i16 tmp1 = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(
150*5f39d1b3SJooyung Han               zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
151*5f39d1b3SJooyung Han           sums_of_2_cells[cell][i] = __builtin_msa_addv_h(tmp0, tmp1);
152*5f39d1b3SJooyung Han         }
153*5f39d1b3SJooyung Han       }
154*5f39d1b3SJooyung Han     }
155*5f39d1b3SJooyung Han     v4i32 sums_of_4_cells[kCells][4];
156*5f39d1b3SJooyung Han     for (int i = 0; i < 4; i++) {
157*5f39d1b3SJooyung Han       for (int cell = 0; cell < kCells; cell++) {
158*5f39d1b3SJooyung Han         v4i32 tmp0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(
159*5f39d1b3SJooyung Han             reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
160*5f39d1b3SJooyung Han         v4i32 tmp1 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(
161*5f39d1b3SJooyung Han             reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
162*5f39d1b3SJooyung Han         sums_of_4_cells[cell][i] = __builtin_msa_addv_w(tmp0, tmp1);
163*5f39d1b3SJooyung Han       }
164*5f39d1b3SJooyung Han     }
165*5f39d1b3SJooyung Han     // Update the sums_of_each_slice vector
166*5f39d1b3SJooyung Han     for (int cell = 0; cell < kCells; cell++) {
167*5f39d1b3SJooyung Han       v4i32 s01 = __builtin_msa_addv_w(sums_of_4_cells[cell][0],
168*5f39d1b3SJooyung Han                                        sums_of_4_cells[cell][1]);
169*5f39d1b3SJooyung Han       v4i32 s23 = __builtin_msa_addv_w(sums_of_4_cells[cell][2],
170*5f39d1b3SJooyung Han                                        sums_of_4_cells[cell][3]);
171*5f39d1b3SJooyung Han       v4i32 s = __builtin_msa_addv_w(s01, s23);
172*5f39d1b3SJooyung Han       std::int32_t* sums_of_each_slice_ptr =
173*5f39d1b3SJooyung Han           dst->sums_of_each_slice() + start_width + 4 * cell;
174*5f39d1b3SJooyung Han       v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
175*5f39d1b3SJooyung Han       tmp = __builtin_msa_addv_w(tmp, s);
176*5f39d1b3SJooyung Han       __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
177*5f39d1b3SJooyung Han     }
178*5f39d1b3SJooyung Han     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
179*5f39d1b3SJooyung Han   }
180*5f39d1b3SJooyung Han };
181*5f39d1b3SJooyung Han 
182*5f39d1b3SJooyung Han template <int Cells>
183*5f39d1b3SJooyung Han using WidthMajorSideFormatNCells4x2 =
184*5f39d1b3SJooyung Han     KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
185*5f39d1b3SJooyung Han 
186*5f39d1b3SJooyung Han template <int Cells>
187*5f39d1b3SJooyung Han class PackingRegisterBlock<
188*5f39d1b3SJooyung Han     WidthMajorUint8SideMap,
189*5f39d1b3SJooyung Han     PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
190*5f39d1b3SJooyung Han     : public PackingRegisterBlockBase<
191*5f39d1b3SJooyung Han           WidthMajorUint8SideMap,
192*5f39d1b3SJooyung Han           PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
193*5f39d1b3SJooyung Han  public:
194*5f39d1b3SJooyung Han   typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
195*5f39d1b3SJooyung Han   typedef typename KernelSideFormat::Cell CellFormat;
196*5f39d1b3SJooyung Han   static constexpr int kCells = KernelSideFormat::kCells;
197*5f39d1b3SJooyung Han   static const int kCellWidth = CellFormat::kWidth;
198*5f39d1b3SJooyung Han   static const int kKernelWidth = CellFormat::kWidth * kCells;
199*5f39d1b3SJooyung Han   static const int kCellDepth = CellFormat::kDepth;
200*5f39d1b3SJooyung Han   static const int kCellSize = CellFormat::kSize;
201*5f39d1b3SJooyung Han 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)202*5f39d1b3SJooyung Han   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
203*5f39d1b3SJooyung Han     std::uint8_t* dst_ptr = dst->current_data();
204*5f39d1b3SJooyung Han     const std::uint8_t* src_ptr = this->complete_src_.data();
205*5f39d1b3SJooyung Han     const int stride = this->complete_src_.stride();
206*5f39d1b3SJooyung Han     // Load source WidthMajor data
207*5f39d1b3SJooyung Han     v8i16 src_lines[kCells * 4];
208*5f39d1b3SJooyung Han     for (int i = 0; i < kCells; i++) {
209*5f39d1b3SJooyung Han #define GEMMLOWP_UNROLLED_LOOP_ITER(k)                           \
210*5f39d1b3SJooyung Han   src_lines[4 * i + k] =                                         \
211*5f39d1b3SJooyung Han       __builtin_msa_ld_h(const_cast<std::uint8_t*>(src_ptr), 0); \
212*5f39d1b3SJooyung Han   src_ptr += stride;
213*5f39d1b3SJooyung Han 
214*5f39d1b3SJooyung Han       GEMMLOWP_UNROLLED_LOOP_ITER(0)
215*5f39d1b3SJooyung Han       GEMMLOWP_UNROLLED_LOOP_ITER(1)
216*5f39d1b3SJooyung Han       GEMMLOWP_UNROLLED_LOOP_ITER(2)
217*5f39d1b3SJooyung Han       GEMMLOWP_UNROLLED_LOOP_ITER(3)
218*5f39d1b3SJooyung Han 
219*5f39d1b3SJooyung Han #undef GEMMLOWP_UNROLLED_LOOP_ITER
220*5f39d1b3SJooyung Han     }
221*5f39d1b3SJooyung Han     // Reorder the data within registers to make WidthMajor 4x2 cells
222*5f39d1b3SJooyung Han     v8i16 src_lines_intertwined_2x[2 * kCells][2];
223*5f39d1b3SJooyung Han     for (int i = 0; i < kCells; i++) {
224*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i][0] =
225*5f39d1b3SJooyung Han           __builtin_msa_ilvr_h(src_lines[4 * i + 2], src_lines[4 * i]);
226*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i][1] =
227*5f39d1b3SJooyung Han           __builtin_msa_ilvl_h(src_lines[4 * i + 2], src_lines[4 * i]);
228*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i + 1][0] =
229*5f39d1b3SJooyung Han           __builtin_msa_ilvr_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
230*5f39d1b3SJooyung Han       src_lines_intertwined_2x[2 * i + 1][1] =
231*5f39d1b3SJooyung Han           __builtin_msa_ilvl_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
232*5f39d1b3SJooyung Han     }
233*5f39d1b3SJooyung Han     v8i16 src_lines_intertwined_4x[2 * kCells][2];
234*5f39d1b3SJooyung Han     for (int i = 0; i < kCells; i++) {
235*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i][0] =
236*5f39d1b3SJooyung Han           __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][0],
237*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][0]);
238*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i][1] =
239*5f39d1b3SJooyung Han           __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][0],
240*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][0]);
241*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i + 1][0] =
242*5f39d1b3SJooyung Han           __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][1],
243*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][1]);
244*5f39d1b3SJooyung Han       src_lines_intertwined_4x[2 * i + 1][1] =
245*5f39d1b3SJooyung Han           __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][1],
246*5f39d1b3SJooyung Han                                src_lines_intertwined_2x[2 * i][1]);
247*5f39d1b3SJooyung Han     }
248*5f39d1b3SJooyung Han     // Store the resulting WidthMajor 4x2 cells in the destination packed block
249*5f39d1b3SJooyung Han     for (int outer = 0; outer < 2; outer++) {
250*5f39d1b3SJooyung Han       for (int inner = 0; inner < 2; inner++) {
251*5f39d1b3SJooyung Han         if (kCells % 2 == 0) {
252*5f39d1b3SJooyung Han           for (int cell = 0; cell < kCells; cell += 2) {
253*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvr_d(
254*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
255*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
256*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
257*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
258*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
259*5f39d1b3SJooyung Han             dst_ptr += 16;
260*5f39d1b3SJooyung Han           }
261*5f39d1b3SJooyung Han           for (int cell = 0; cell < kCells; cell += 2) {
262*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvl_d(
263*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
264*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
265*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
266*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
267*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
268*5f39d1b3SJooyung Han             dst_ptr += 16;
269*5f39d1b3SJooyung Han           }
270*5f39d1b3SJooyung Han         } else {
271*5f39d1b3SJooyung Han           // Store even number of low vector halves.
272*5f39d1b3SJooyung Han           for (int cell = 0; cell < kCells - 1; cell += 2) {
273*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvr_d(
274*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
275*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
276*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
277*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
278*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
279*5f39d1b3SJooyung Han             dst_ptr += 16;
280*5f39d1b3SJooyung Han           }
281*5f39d1b3SJooyung Han           // Store last low half and first high half.
282*5f39d1b3SJooyung Han           v2i64 tmp = reinterpret_cast<v2i64>(
283*5f39d1b3SJooyung Han               src_lines_intertwined_4x[2 * 0 + outer][inner]);
284*5f39d1b3SJooyung Han           tmp = __builtin_msa_insve_d(
285*5f39d1b3SJooyung Han               tmp, 0,
286*5f39d1b3SJooyung Han               reinterpret_cast<v2i64>(
287*5f39d1b3SJooyung Han                   src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
288*5f39d1b3SJooyung Han           __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
289*5f39d1b3SJooyung Han           dst_ptr += 16;
290*5f39d1b3SJooyung Han           // Store even number of high vector halves.
291*5f39d1b3SJooyung Han           for (int cell = 1; cell < kCells; cell += 2) {
292*5f39d1b3SJooyung Han             v2i64 tmp = __builtin_msa_ilvl_d(
293*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
294*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
295*5f39d1b3SJooyung Han                 reinterpret_cast<v2i64>(
296*5f39d1b3SJooyung Han                     src_lines_intertwined_4x[2 * cell + outer][inner]));
297*5f39d1b3SJooyung Han             __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
298*5f39d1b3SJooyung Han             dst_ptr += 16;
299*5f39d1b3SJooyung Han           }
300*5f39d1b3SJooyung Han         }
301*5f39d1b3SJooyung Han       }
302*5f39d1b3SJooyung Han     }
303*5f39d1b3SJooyung Han     // Compute sums across the depth dimension
304*5f39d1b3SJooyung Han     v8i16 sums_of_2[kCells][4];
305*5f39d1b3SJooyung Han     for (int outer = 0; outer < 2; outer++) {
306*5f39d1b3SJooyung Han       for (int inner = 0; inner < 2; inner++) {
307*5f39d1b3SJooyung Han         int i = 2 * outer + inner;
308*5f39d1b3SJooyung Han         for (int cell = 0; cell < kCells; cell++) {
309*5f39d1b3SJooyung Han           sums_of_2[cell][i] = reinterpret_cast<v8i16>(__builtin_msa_hadd_u_h(
310*5f39d1b3SJooyung Han               reinterpret_cast<v16u8>(
311*5f39d1b3SJooyung Han                   src_lines_intertwined_4x[2 * cell + outer][inner]),
312*5f39d1b3SJooyung Han               reinterpret_cast<v16u8>(
313*5f39d1b3SJooyung Han                   src_lines_intertwined_4x[2 * cell + outer][inner])));
314*5f39d1b3SJooyung Han         }
315*5f39d1b3SJooyung Han       }
316*5f39d1b3SJooyung Han     }
317*5f39d1b3SJooyung Han     v8i16 sums_of_4[kCells][2];
318*5f39d1b3SJooyung Han     for (int i = 0; i < 2; i++) {
319*5f39d1b3SJooyung Han       for (int cell = 0; cell < kCells; cell++) {
320*5f39d1b3SJooyung Han         sums_of_4[cell][i] = __builtin_msa_addv_h(sums_of_2[cell][2 * i],
321*5f39d1b3SJooyung Han                                                   sums_of_2[cell][2 * i + 1]);
322*5f39d1b3SJooyung Han       }
323*5f39d1b3SJooyung Han     }
324*5f39d1b3SJooyung Han     v8i16 sums_of_8[kCells];
325*5f39d1b3SJooyung Han     for (int cell = 0; cell < kCells; cell++) {
326*5f39d1b3SJooyung Han       sums_of_8[cell] =
327*5f39d1b3SJooyung Han           __builtin_msa_addv_h(sums_of_4[cell][0], sums_of_4[cell][1]);
328*5f39d1b3SJooyung Han     }
329*5f39d1b3SJooyung Han 
330*5f39d1b3SJooyung Han     v4i32 sums_of_16[kCells];
331*5f39d1b3SJooyung Han     const v8i16 zeroes = __builtin_msa_ldi_h(0);
332*5f39d1b3SJooyung Han     for (int cell = 0; cell < kCells; cell++) {
333*5f39d1b3SJooyung Han       sums_of_16[cell] = reinterpret_cast<v4i32>(
334*5f39d1b3SJooyung Han           __builtin_msa_ilvr_h(zeroes, sums_of_8[cell]));
335*5f39d1b3SJooyung Han       v8i16 tmp = __builtin_msa_ilvl_h(zeroes, sums_of_8[cell]);
336*5f39d1b3SJooyung Han       sums_of_16[cell] =
337*5f39d1b3SJooyung Han           __builtin_msa_addv_w(sums_of_16[cell], reinterpret_cast<v4i32>(tmp));
338*5f39d1b3SJooyung Han     }
339*5f39d1b3SJooyung Han     // Update the sums_of_each_slice vector
340*5f39d1b3SJooyung Han     for (int cell = 0; cell < kCells; cell++) {
341*5f39d1b3SJooyung Han       std::int32_t* sums_of_each_slice_ptr =
342*5f39d1b3SJooyung Han           dst->sums_of_each_slice() + start_width + 4 * cell;
343*5f39d1b3SJooyung Han       v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
344*5f39d1b3SJooyung Han       tmp = __builtin_msa_addv_w(tmp, sums_of_16[cell]);
345*5f39d1b3SJooyung Han       __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
346*5f39d1b3SJooyung Han     }
347*5f39d1b3SJooyung Han     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
348*5f39d1b3SJooyung Han   }
349*5f39d1b3SJooyung Han };
350*5f39d1b3SJooyung Han 
351*5f39d1b3SJooyung Han template <int Width>
352*5f39d1b3SJooyung Han using Int8FastKernelFormat =
353*5f39d1b3SJooyung Han     KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
354*5f39d1b3SJooyung Han 
355*5f39d1b3SJooyung Han template <int Width>
356*5f39d1b3SJooyung Han class PackingRegisterBlock<WidthMajorUint8SideMap,
357*5f39d1b3SJooyung Han                            PackedSideBlock<Int8FastKernelFormat<Width>>>
358*5f39d1b3SJooyung Han     : public PackingRegisterBlockBase<
359*5f39d1b3SJooyung Han           WidthMajorUint8SideMap,
360*5f39d1b3SJooyung Han           PackedSideBlock<Int8FastKernelFormat<Width>>> {
361*5f39d1b3SJooyung Han  public:
362*5f39d1b3SJooyung Han   static_assert(Width == 2 || Width == 4, "");
363*5f39d1b3SJooyung Han   typedef Int8FastKernelFormat<Width> KernelSideFormat;
364*5f39d1b3SJooyung Han   typedef typename KernelSideFormat::Cell CellFormat;
365*5f39d1b3SJooyung Han   static const int kCells = KernelSideFormat::kCells;
366*5f39d1b3SJooyung Han   static const int kCellWidth = CellFormat::kWidth;
367*5f39d1b3SJooyung Han   static const int kKernelWidth = CellFormat::kWidth * kCells;
368*5f39d1b3SJooyung Han   static const int kCellDepth = CellFormat::kDepth;
369*5f39d1b3SJooyung Han   static const int kCellSize = CellFormat::kSize;
370*5f39d1b3SJooyung Han 
Pack(PackedSideBlock<KernelSideFormat> * dst,int start_width)371*5f39d1b3SJooyung Han   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
372*5f39d1b3SJooyung Han     std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
373*5f39d1b3SJooyung Han     std::uint8_t* dst_ptr = dst->current_data();
374*5f39d1b3SJooyung Han     const std::uint8_t* const src_ptr = this->complete_src_.data();
375*5f39d1b3SJooyung Han     const int stride = this->complete_src_.stride();
376*5f39d1b3SJooyung Han     // Load source WidthMajor data.
377*5f39d1b3SJooyung Han     v16i8 src_lines[Width];
378*5f39d1b3SJooyung Han     for (int i = 0; i < Width; i++) {
379*5f39d1b3SJooyung Han       src_lines[i] = __builtin_msa_ld_b(
380*5f39d1b3SJooyung Han           const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
381*5f39d1b3SJooyung Han     }
382*5f39d1b3SJooyung Han     for (int i = 0; i < Width; i++) {
383*5f39d1b3SJooyung Han       // Subtract 128 by inverting bit 7.
384*5f39d1b3SJooyung Han       src_lines[i] = reinterpret_cast<v16i8>(
385*5f39d1b3SJooyung Han           __builtin_msa_bnegi_b(reinterpret_cast<v16u8>(src_lines[i]), 7));
386*5f39d1b3SJooyung Han     }
387*5f39d1b3SJooyung Han     for (int i = 0; i < Width; i++) {
388*5f39d1b3SJooyung Han       __builtin_msa_st_b(src_lines[i], dst_ptr + 16 * i, 0);
389*5f39d1b3SJooyung Han     }
390*5f39d1b3SJooyung Han     v8i16 sums2[Width];
391*5f39d1b3SJooyung Han     for (int i = 0; i < Width; i++) {
392*5f39d1b3SJooyung Han       sums2[i] = __builtin_msa_hadd_s_h(src_lines[i], src_lines[i]);
393*5f39d1b3SJooyung Han     }
394*5f39d1b3SJooyung Han     v4i32 sums4_wide[Width];
395*5f39d1b3SJooyung Han     for (int i = 0; i < Width; i++) {
396*5f39d1b3SJooyung Han       sums4_wide[i] = __builtin_msa_hadd_s_w(sums2[i], sums2[i]);
397*5f39d1b3SJooyung Han     }
398*5f39d1b3SJooyung Han     v8i16 sums4[Width / 2];
399*5f39d1b3SJooyung Han     for (int i = 0; i < Width / 2; i++) {
400*5f39d1b3SJooyung Han       sums4[i] = __builtin_msa_pckev_h(
401*5f39d1b3SJooyung Han           reinterpret_cast<v8i16>(sums4_wide[2 * i + 1]),
402*5f39d1b3SJooyung Han           reinterpret_cast<v8i16>(sums4_wide[2 * i]));
403*5f39d1b3SJooyung Han     }
404*5f39d1b3SJooyung Han     v4i32 sums8_wide[Width / 2];
405*5f39d1b3SJooyung Han     for (int i = 0; i < Width / 2; i++) {
406*5f39d1b3SJooyung Han       sums8_wide[i] = __builtin_msa_hadd_s_w(sums4[i], sums4[i]);
407*5f39d1b3SJooyung Han     }
408*5f39d1b3SJooyung Han     if (Width == 4) {
409*5f39d1b3SJooyung Han       v4i32 sum = __builtin_msa_ld_w(const_cast<std::int32_t*>(sums_ptr), 0);
410*5f39d1b3SJooyung Han       v8i16 sums8 = __builtin_msa_pckev_h(
411*5f39d1b3SJooyung Han           reinterpret_cast<v8i16>(sums8_wide[1]),
412*5f39d1b3SJooyung Han           reinterpret_cast<v8i16>(sums8_wide[0]));
413*5f39d1b3SJooyung Han       v4i32 sums16 = __builtin_msa_hadd_s_w(sums8, sums8);
414*5f39d1b3SJooyung Han       sum = __builtin_msa_addv_w(sum, sums16);
415*5f39d1b3SJooyung Han       __builtin_msa_st_w(sum, sums_ptr, 0);
416*5f39d1b3SJooyung Han     } else {
417*5f39d1b3SJooyung Han       assert(Width == 2);
418*5f39d1b3SJooyung Han       std::int32_t sum[2] = { sums_ptr[0], sums_ptr[1] };
419*5f39d1b3SJooyung Han       v2i64 sums16 = __builtin_msa_hadd_s_d(sums8_wide[0], sums8_wide[0]);
420*5f39d1b3SJooyung Han       sum[0] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 0);
421*5f39d1b3SJooyung Han       sum[1] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 2);
422*5f39d1b3SJooyung Han       sums_ptr[0] = sum[0];
423*5f39d1b3SJooyung Han       sums_ptr[1] = sum[1];
424*5f39d1b3SJooyung Han     }
425*5f39d1b3SJooyung Han     dst->seek_forward_n_cells(1);
426*5f39d1b3SJooyung Han   }
427*5f39d1b3SJooyung Han };
428*5f39d1b3SJooyung Han 
429*5f39d1b3SJooyung Han }  // namespace gemmlowp
430*5f39d1b3SJooyung Han 
431*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_PACK_MSA_H_
432