xref: /aosp_15_r20/external/ruy/ruy/pack_arm.h (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1*bb86c7edSAndroid Build Coastguard Worker /* Copyright 2019 Google LLC. All Rights Reserved.
2*bb86c7edSAndroid Build Coastguard Worker 
3*bb86c7edSAndroid Build Coastguard Worker Licensed under the Apache License, Version 2.0 (the "License");
4*bb86c7edSAndroid Build Coastguard Worker you may not use this file except in compliance with the License.
5*bb86c7edSAndroid Build Coastguard Worker You may obtain a copy of the License at
6*bb86c7edSAndroid Build Coastguard Worker 
7*bb86c7edSAndroid Build Coastguard Worker     http://www.apache.org/licenses/LICENSE-2.0
8*bb86c7edSAndroid Build Coastguard Worker 
9*bb86c7edSAndroid Build Coastguard Worker Unless required by applicable law or agreed to in writing, software
10*bb86c7edSAndroid Build Coastguard Worker distributed under the License is distributed on an "AS IS" BASIS,
11*bb86c7edSAndroid Build Coastguard Worker WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*bb86c7edSAndroid Build Coastguard Worker See the License for the specific language governing permissions and
13*bb86c7edSAndroid Build Coastguard Worker limitations under the License.
14*bb86c7edSAndroid Build Coastguard Worker ==============================================================================*/
15*bb86c7edSAndroid Build Coastguard Worker 
16*bb86c7edSAndroid Build Coastguard Worker #ifndef RUY_RUY_PACK_ARM_H_
17*bb86c7edSAndroid Build Coastguard Worker #define RUY_RUY_PACK_ARM_H_
18*bb86c7edSAndroid Build Coastguard Worker 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/asm_helpers.h"
#include "ruy/check_macros.h"
#include "ruy/mat.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/tune.h"
32*bb86c7edSAndroid Build Coastguard Worker 
33*bb86c7edSAndroid Build Coastguard Worker namespace ruy {
34*bb86c7edSAndroid Build Coastguard Worker 
#if RUY_PLATFORM_NEON
// Path inheritance for packing. NOTE(review): RUY_INHERIT_PACK is defined in
// ruy/pack_common.h; presumably it lets kNeon reuse kStandardCpp's PackImpl
// specializations and kNeonDotprod reuse kNeon's where no more specific one
// exists — confirm there.
RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)

// NOTE(review): presumably selects a memcpy-based pack for row-major float
// sources with the given kernel width (8 on all NEON, additionally 4 on
// 32-bit NEON) — see the macro definition in ruy/pack_common.h.
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 8)
#if RUY_PLATFORM_NEON_32
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 4)
#endif

// On both NEON paths, uint8 source data is packed as int8. The 8-bit PackImpl
// specializations in this header do the conversion by passing an input XOR
// of 0x80 (kInputXor) for uint8 sources.
template <>
struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
  using Type = std::int8_t;
};
template <>
struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
  using Type = std::int8_t;
};
#endif
53*bb86c7edSAndroid Build Coastguard Worker 
54*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON
// Packs 8-bit row-major source data; declared for both 32- and 64-bit NEON
// and implemented out of line. No caller is visible in this header excerpt.
// NOTE(review): parameter roles are not established here. For example,
// input_xor presumably follows the kInputXor convention of the PackImpl
// specializations below (0 for int8, 0x80 for uint8), and kernel_cols is
// presumably the kernel layout's column count. Confirm against the
// implementation (presumably pack_arm.cc).
void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
                             int src_rows, int src_cols, int block_row,
                             int start_col, int end_col,
                             std::int8_t* packed_ptr, int packed_stride,
                             int packed_zero_point, std::int32_t* sums_ptr,
                             int input_xor, int kernel_cols);
61*bb86c7edSAndroid Build Coastguard Worker #endif
62*bb86c7edSAndroid Build Coastguard Worker 
63*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
64*bb86c7edSAndroid Build Coastguard Worker 
// 8-bit pack routines for 64-bit NEON, implemented out of line
// (NOTE(review): presumably in pack_arm.cc — confirm).
//
// Column-major packers: each call handles four source columns. As called by
// the PackImpl specializations in this header:
//  - src_ptr0..src_ptr3 point at the four source columns. A column at or past
//    the end of the source matrix is replaced by a 16-element buffer filled
//    with the source zero point.
//  - src_inc0..src_inc3 are 16 for real columns and 0 for that buffer
//    (presumably so the buffer is re-read rather than advanced).
//  - src_rows / src_zero_point come from the source matrix.
//  - sums_ptr may be nullptr when the packed matrix has no sums.
//  - input_xor is 0 for int8 sources and 0x80 for uint8 sources.
// Pack8bitColMajorForNeon{,A55ish} serve the kNeon 16x4 layout, and
// Pack8bitColMajorForNeonDotprod{,A55ish} serve the kNeonDotprod 4x8 layout.
// The *A55ish variants are used when tuning == Tuning::kA55ish.
void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1,
                             const void* src_ptr2, const void* src_ptr3,
                             int src_inc0, int src_inc1, int src_inc2,
                             int src_inc3, int src_rows, int src_zero_point,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr,
                             int input_xor);
void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1,
                                   const void* src_ptr2, const void* src_ptr3,
                                   int src_inc0, int src_inc1, int src_inc2,
                                   int src_inc3, int src_rows,
                                   int src_zero_point, std::int8_t* packed_ptr,
                                   std::int32_t* sums_ptr, int input_xor);
void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
                                    const void* src_ptr2, const void* src_ptr3,
                                    int src_inc0, int src_inc1, int src_inc2,
                                    int src_inc3, int src_rows,
                                    int src_zero_point, std::int8_t* packed_ptr,
                                    std::int32_t* sums_ptr, int input_xor);
void Pack8bitColMajorForNeonDotprodA55ish(
    const void* src_ptr0, const void* src_ptr1, const void* src_ptr2,
    const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2,
    int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr,
    std::int32_t* sums_ptr, int input_xor);
// Row-major counterpart for the dotprod path; its caller is not visible here.
// NOTE(review): it takes src_cols and packed_stride instead of src_rows, which
// suggests four source rows per call — confirm against the implementation.
void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
                                    const void* src_ptr2, const void* src_ptr3,
                                    int src_inc0, int src_inc1, int src_inc2,
                                    int src_inc3, int src_cols,
                                    int src_zero_point, std::int8_t* packed_ptr,
                                    int packed_stride, std::int32_t* sums_ptr,
                                    int input_xor);
95*bb86c7edSAndroid Build Coastguard Worker #elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
96*bb86c7edSAndroid Build Coastguard Worker 
// Argument bundle for the ARMv7 (32-bit NEON) 8-bit column-major pack
// routines Pack8bitColMajorForNeon4Cols / Pack8bitColMajorForNeon2Cols.
// ARMv7 has a more limited set of general purpose registers, so the
// arguments are passed through this struct instead of individually.
// Filled in by MakePackParams8bit.
// NOTE(review): the out-of-line assembly presumably reads these fields at
// fixed byte offsets. Do not reorder, insert or remove fields without
// updating it — confirm against pack_arm.cc.
struct PackParams8bit {
  // Source column pointers. Unused slots are nullptr (the 2-column pack).
  const void* src_ptr0;
  const void* src_ptr1;
  const void* src_ptr2;
  const void* src_ptr3;
  // Column sums output, or nullptr when the packed matrix has no sums.
  // NOTE(review): declared const although it is an output — presumably the
  // assembly writes through it anyway; confirm.
  const std::int32_t* sums_ptr;
  // Destination of the packed data (same const caveat as sums_ptr).
  const std::int8_t* packed_ptr;
  // Per-column source increments. Callers in this header pass 16 for real
  // columns, 0 for the zero-point padding buffer, and -1 for unused slots.
  int src_inc0;
  int src_inc1;
  int src_inc2;
  int src_inc3;
  // Number of source rows.
  int src_rows;
  // Zero point of the source matrix.
  int src_zero_point;
  // 0 for int8 sources, 0x80 for uint8 sources.
  int input_xor;
};
112*bb86c7edSAndroid Build Coastguard Worker 
113*bb86c7edSAndroid Build Coastguard Worker inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1,
114*bb86c7edSAndroid Build Coastguard Worker                                const void* src_ptr2, const void* src_ptr3,
115*bb86c7edSAndroid Build Coastguard Worker                                const std::int32_t* sums_ptr,
116*bb86c7edSAndroid Build Coastguard Worker                                const std::int8_t* packed_ptr, int src_inc0,
117*bb86c7edSAndroid Build Coastguard Worker                                int src_inc1, int src_inc2, int src_inc3,
118*bb86c7edSAndroid Build Coastguard Worker                                int src_rows, int src_zero_point, int input_xor,
119*bb86c7edSAndroid Build Coastguard Worker                                PackParams8bit* params) {
120*bb86c7edSAndroid Build Coastguard Worker   params->src_ptr0 = src_ptr0;
121*bb86c7edSAndroid Build Coastguard Worker   params->src_ptr1 = src_ptr1;
122*bb86c7edSAndroid Build Coastguard Worker   params->src_ptr2 = src_ptr2;
123*bb86c7edSAndroid Build Coastguard Worker   params->src_ptr3 = src_ptr3;
124*bb86c7edSAndroid Build Coastguard Worker   params->sums_ptr = sums_ptr;
125*bb86c7edSAndroid Build Coastguard Worker   params->packed_ptr = packed_ptr;
126*bb86c7edSAndroid Build Coastguard Worker   params->src_inc0 = src_inc0;
127*bb86c7edSAndroid Build Coastguard Worker   params->src_inc1 = src_inc1;
128*bb86c7edSAndroid Build Coastguard Worker   params->src_inc2 = src_inc2;
129*bb86c7edSAndroid Build Coastguard Worker   params->src_inc3 = src_inc3;
130*bb86c7edSAndroid Build Coastguard Worker   params->src_rows = src_rows;
131*bb86c7edSAndroid Build Coastguard Worker   params->src_zero_point = src_zero_point;
132*bb86c7edSAndroid Build Coastguard Worker   params->input_xor = input_xor;
133*bb86c7edSAndroid Build Coastguard Worker }
134*bb86c7edSAndroid Build Coastguard Worker 
// ARMv7 8-bit column-major pack routines that take their arguments through a
// PackParams8bit built by MakePackParams8bit. Implemented out of line.
// Pack8bitColMajorForNeon4Cols is used by the kNeon 16x4 PackImpl
// specialization, which passes four source columns per call.
// Pack8bitColMajorForNeon2Cols is used by the 16x2 specialization, which
// fills only src_ptr0/src_ptr1 (src_ptr2/src_ptr3 are nullptr and
// src_inc2/src_inc3 are -1).
void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params);
void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params);
137*bb86c7edSAndroid Build Coastguard Worker 
138*bb86c7edSAndroid Build Coastguard Worker #endif  // (RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
139*bb86c7edSAndroid Build Coastguard Worker 
140*bb86c7edSAndroid Build Coastguard Worker #if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
141*bb86c7edSAndroid Build Coastguard Worker 
142*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar>
143*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
144*bb86c7edSAndroid Build Coastguard Worker                 std::int8_t, std::int32_t, Order::kColMajor> {
145*bb86c7edSAndroid Build Coastguard Worker   static_assert(std::is_same<Scalar, std::int8_t>::value ||
146*bb86c7edSAndroid Build Coastguard Worker                     std::is_same<Scalar, std::uint8_t>::value,
147*bb86c7edSAndroid Build Coastguard Worker                 "");
148*bb86c7edSAndroid Build Coastguard Worker   static constexpr int kInputXor =
149*bb86c7edSAndroid Build Coastguard Worker       std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
150*bb86c7edSAndroid Build Coastguard Worker 
151*bb86c7edSAndroid Build Coastguard Worker   static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
152*bb86c7edSAndroid Build Coastguard Worker                   PMat<std::int8_t>* packed_matrix, int start_col,
153*bb86c7edSAndroid Build Coastguard Worker                   int end_col) {
154*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(src_matrix.layout));
155*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(packed_matrix->layout));
156*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ(start_col % 4, 0);
157*bb86c7edSAndroid Build Coastguard Worker     std::int32_t* sums = packed_matrix->sums;
158*bb86c7edSAndroid Build Coastguard Worker     Scalar zerobuf[16];
159*bb86c7edSAndroid Build Coastguard Worker     memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
160*bb86c7edSAndroid Build Coastguard Worker     for (int block_col = start_col; block_col < end_col; block_col += 4) {
161*bb86c7edSAndroid Build Coastguard Worker       int src_stride = src_matrix.layout.stride;
162*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
163*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr1 = src_ptr0 + src_stride;
164*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr2 = src_ptr1 + src_stride;
165*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr3 = src_ptr2 + src_stride;
166*bb86c7edSAndroid Build Coastguard Worker       int src_inc0 = 16;
167*bb86c7edSAndroid Build Coastguard Worker       int src_inc1 = 16;
168*bb86c7edSAndroid Build Coastguard Worker       int src_inc2 = 16;
169*bb86c7edSAndroid Build Coastguard Worker       int src_inc3 = 16;
170*bb86c7edSAndroid Build Coastguard Worker       if (block_col >= src_matrix.layout.cols - 3) {
171*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 0) {
172*bb86c7edSAndroid Build Coastguard Worker           src_ptr0 = zerobuf;
173*bb86c7edSAndroid Build Coastguard Worker           src_inc0 = 0;
174*bb86c7edSAndroid Build Coastguard Worker         }
175*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 1) {
176*bb86c7edSAndroid Build Coastguard Worker           src_ptr1 = zerobuf;
177*bb86c7edSAndroid Build Coastguard Worker           src_inc1 = 0;
178*bb86c7edSAndroid Build Coastguard Worker         }
179*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 2) {
180*bb86c7edSAndroid Build Coastguard Worker           src_ptr2 = zerobuf;
181*bb86c7edSAndroid Build Coastguard Worker           src_inc2 = 0;
182*bb86c7edSAndroid Build Coastguard Worker         }
183*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 3) {
184*bb86c7edSAndroid Build Coastguard Worker           src_ptr3 = zerobuf;
185*bb86c7edSAndroid Build Coastguard Worker           src_inc3 = 0;
186*bb86c7edSAndroid Build Coastguard Worker         }
187*bb86c7edSAndroid Build Coastguard Worker       }
188*bb86c7edSAndroid Build Coastguard Worker       std::int8_t* packed_ptr =
189*bb86c7edSAndroid Build Coastguard Worker           packed_matrix->data + packed_matrix->layout.stride * block_col;
190*bb86c7edSAndroid Build Coastguard Worker       std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
191*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64
192*bb86c7edSAndroid Build Coastguard Worker       if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
193*bb86c7edSAndroid Build Coastguard Worker         Pack8bitColMajorForNeonA55ish(
194*bb86c7edSAndroid Build Coastguard Worker             src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
195*bb86c7edSAndroid Build Coastguard Worker             src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
196*bb86c7edSAndroid Build Coastguard Worker             packed_ptr, sums_ptr, kInputXor);
197*bb86c7edSAndroid Build Coastguard Worker       } else {
198*bb86c7edSAndroid Build Coastguard Worker         Pack8bitColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
199*bb86c7edSAndroid Build Coastguard Worker                                 src_inc0, src_inc1, src_inc2, src_inc3,
200*bb86c7edSAndroid Build Coastguard Worker                                 src_matrix.layout.rows, src_matrix.zero_point,
201*bb86c7edSAndroid Build Coastguard Worker                                 packed_ptr, sums_ptr, kInputXor);
202*bb86c7edSAndroid Build Coastguard Worker       }
203*bb86c7edSAndroid Build Coastguard Worker #else
204*bb86c7edSAndroid Build Coastguard Worker       (void)tuning;
205*bb86c7edSAndroid Build Coastguard Worker       // We have a more limited set of general purpose registers in ARMv7, so
206*bb86c7edSAndroid Build Coastguard Worker       // we use the "params" struct technique from the kernel code to save
207*bb86c7edSAndroid Build Coastguard Worker       // registers.
208*bb86c7edSAndroid Build Coastguard Worker       PackParams8bit params;
209*bb86c7edSAndroid Build Coastguard Worker       MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr,
210*bb86c7edSAndroid Build Coastguard Worker                          packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3,
211*bb86c7edSAndroid Build Coastguard Worker                          src_matrix.layout.rows, src_matrix.zero_point,
212*bb86c7edSAndroid Build Coastguard Worker                          kInputXor, &params);
213*bb86c7edSAndroid Build Coastguard Worker       Pack8bitColMajorForNeon4Cols(params);
214*bb86c7edSAndroid Build Coastguard Worker #endif  // RUY_PLATFORM_NEON_64
215*bb86c7edSAndroid Build Coastguard Worker     }
216*bb86c7edSAndroid Build Coastguard Worker   }
217*bb86c7edSAndroid Build Coastguard Worker };
218*bb86c7edSAndroid Build Coastguard Worker 
219*bb86c7edSAndroid Build Coastguard Worker #endif  // (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) &&
220*bb86c7edSAndroid Build Coastguard Worker         // RUY_OPT(ASM)
221*bb86c7edSAndroid Build Coastguard Worker 
222*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
223*bb86c7edSAndroid Build Coastguard Worker // The 32-bit float kernel is 4 rows X 2 columns, so we need an additional
224*bb86c7edSAndroid Build Coastguard Worker // partial specialization for the RHS, which has a FixedKernelLayout with 2
225*bb86c7edSAndroid Build Coastguard Worker // columns.
226*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar>
227*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar,
228*bb86c7edSAndroid Build Coastguard Worker                 std::int8_t, std::int32_t, Order::kColMajor> {
229*bb86c7edSAndroid Build Coastguard Worker   static_assert(std::is_same<Scalar, std::int8_t>::value ||
230*bb86c7edSAndroid Build Coastguard Worker                     std::is_same<Scalar, std::uint8_t>::value,
231*bb86c7edSAndroid Build Coastguard Worker                 "");
232*bb86c7edSAndroid Build Coastguard Worker   static constexpr int kInputXor =
233*bb86c7edSAndroid Build Coastguard Worker       std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
234*bb86c7edSAndroid Build Coastguard Worker   static void Run(Tuning, const Mat<Scalar>& src_matrix,
235*bb86c7edSAndroid Build Coastguard Worker                   PMat<std::int8_t>* packed_matrix, int start_col,
236*bb86c7edSAndroid Build Coastguard Worker                   int end_col) {
237*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(src_matrix.layout));
238*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(packed_matrix->layout));
239*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ(start_col % 2, 0);
240*bb86c7edSAndroid Build Coastguard Worker     std::int32_t* sums = packed_matrix->sums;
241*bb86c7edSAndroid Build Coastguard Worker     Scalar zerobuf[16];
242*bb86c7edSAndroid Build Coastguard Worker     memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
243*bb86c7edSAndroid Build Coastguard Worker     for (int block_col = start_col; block_col < end_col; block_col += 2) {
244*bb86c7edSAndroid Build Coastguard Worker       int src_stride = src_matrix.layout.stride;
245*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
246*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr1 = src_ptr0 + src_stride;
247*bb86c7edSAndroid Build Coastguard Worker       int src_inc0 = 16;
248*bb86c7edSAndroid Build Coastguard Worker       int src_inc1 = 16;
249*bb86c7edSAndroid Build Coastguard Worker       if (block_col >= src_matrix.layout.cols - 2) {
250*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 0) {
251*bb86c7edSAndroid Build Coastguard Worker           src_ptr0 = zerobuf;
252*bb86c7edSAndroid Build Coastguard Worker           src_inc0 = 0;
253*bb86c7edSAndroid Build Coastguard Worker         }
254*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 1) {
255*bb86c7edSAndroid Build Coastguard Worker           src_ptr1 = zerobuf;
256*bb86c7edSAndroid Build Coastguard Worker           src_inc1 = 0;
257*bb86c7edSAndroid Build Coastguard Worker         }
258*bb86c7edSAndroid Build Coastguard Worker       }
259*bb86c7edSAndroid Build Coastguard Worker       std::int8_t* packed_ptr =
260*bb86c7edSAndroid Build Coastguard Worker           packed_matrix->data + packed_matrix->layout.stride * block_col;
261*bb86c7edSAndroid Build Coastguard Worker       std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
262*bb86c7edSAndroid Build Coastguard Worker       PackParams8bit params;
263*bb86c7edSAndroid Build Coastguard Worker       MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr,
264*bb86c7edSAndroid Build Coastguard Worker                          packed_ptr, src_inc0, src_inc1, -1, -1,
265*bb86c7edSAndroid Build Coastguard Worker                          src_matrix.layout.rows, src_matrix.zero_point,
266*bb86c7edSAndroid Build Coastguard Worker                          kInputXor, &params);
267*bb86c7edSAndroid Build Coastguard Worker       Pack8bitColMajorForNeon2Cols(params);
268*bb86c7edSAndroid Build Coastguard Worker     }
269*bb86c7edSAndroid Build Coastguard Worker   }
270*bb86c7edSAndroid Build Coastguard Worker };
271*bb86c7edSAndroid Build Coastguard Worker #endif  // (RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)
272*bb86c7edSAndroid Build Coastguard Worker 
273*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
274*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar>
275*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
276*bb86c7edSAndroid Build Coastguard Worker                 Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
277*bb86c7edSAndroid Build Coastguard Worker   static_assert(std::is_same<Scalar, std::int8_t>::value ||
278*bb86c7edSAndroid Build Coastguard Worker                     std::is_same<Scalar, std::uint8_t>::value,
279*bb86c7edSAndroid Build Coastguard Worker                 "");
280*bb86c7edSAndroid Build Coastguard Worker   static constexpr int kInputXor =
281*bb86c7edSAndroid Build Coastguard Worker       std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
282*bb86c7edSAndroid Build Coastguard Worker 
283*bb86c7edSAndroid Build Coastguard Worker   static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
284*bb86c7edSAndroid Build Coastguard Worker                   PMat<std::int8_t>* packed_matrix, int start_col,
285*bb86c7edSAndroid Build Coastguard Worker                   int end_col) {
286*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(src_matrix.layout));
287*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(packed_matrix->layout));
288*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ(start_col % 8, 0);
289*bb86c7edSAndroid Build Coastguard Worker     std::int32_t* sums = packed_matrix->sums;
290*bb86c7edSAndroid Build Coastguard Worker     Scalar zerobuf[16];
291*bb86c7edSAndroid Build Coastguard Worker     memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
292*bb86c7edSAndroid Build Coastguard Worker     for (int block_col = start_col; block_col < end_col; block_col += 4) {
293*bb86c7edSAndroid Build Coastguard Worker       int src_stride = src_matrix.layout.stride;
294*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
295*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr1 = src_ptr0 + src_stride;
296*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr2 = src_ptr1 + src_stride;
297*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr3 = src_ptr2 + src_stride;
298*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc0 = 16;
299*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc1 = 16;
300*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc2 = 16;
301*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc3 = 16;
302*bb86c7edSAndroid Build Coastguard Worker       if (block_col >= src_matrix.layout.cols - 3) {
303*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 0) {
304*bb86c7edSAndroid Build Coastguard Worker           src_ptr0 = zerobuf;
305*bb86c7edSAndroid Build Coastguard Worker           src_inc0 = 0;
306*bb86c7edSAndroid Build Coastguard Worker         }
307*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 1) {
308*bb86c7edSAndroid Build Coastguard Worker           src_ptr1 = zerobuf;
309*bb86c7edSAndroid Build Coastguard Worker           src_inc1 = 0;
310*bb86c7edSAndroid Build Coastguard Worker         }
311*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 2) {
312*bb86c7edSAndroid Build Coastguard Worker           src_ptr2 = zerobuf;
313*bb86c7edSAndroid Build Coastguard Worker           src_inc2 = 0;
314*bb86c7edSAndroid Build Coastguard Worker         }
315*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 3) {
316*bb86c7edSAndroid Build Coastguard Worker           src_ptr3 = zerobuf;
317*bb86c7edSAndroid Build Coastguard Worker           src_inc3 = 0;
318*bb86c7edSAndroid Build Coastguard Worker         }
319*bb86c7edSAndroid Build Coastguard Worker       }
320*bb86c7edSAndroid Build Coastguard Worker       std::int8_t* packed_ptr =
321*bb86c7edSAndroid Build Coastguard Worker           packed_matrix->data +
322*bb86c7edSAndroid Build Coastguard Worker           packed_matrix->layout.stride * (block_col & ~7) +
323*bb86c7edSAndroid Build Coastguard Worker           ((block_col & 4) * 4);
324*bb86c7edSAndroid Build Coastguard Worker       std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
325*bb86c7edSAndroid Build Coastguard Worker       if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
326*bb86c7edSAndroid Build Coastguard Worker         Pack8bitColMajorForNeonDotprodA55ish(
327*bb86c7edSAndroid Build Coastguard Worker             src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
328*bb86c7edSAndroid Build Coastguard Worker             src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
329*bb86c7edSAndroid Build Coastguard Worker             packed_ptr, sums_ptr, kInputXor);
330*bb86c7edSAndroid Build Coastguard Worker       } else {
331*bb86c7edSAndroid Build Coastguard Worker         Pack8bitColMajorForNeonDotprod(
332*bb86c7edSAndroid Build Coastguard Worker             src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
333*bb86c7edSAndroid Build Coastguard Worker             src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
334*bb86c7edSAndroid Build Coastguard Worker             packed_ptr, sums_ptr, kInputXor);
335*bb86c7edSAndroid Build Coastguard Worker       }
336*bb86c7edSAndroid Build Coastguard Worker     }
337*bb86c7edSAndroid Build Coastguard Worker   }
338*bb86c7edSAndroid Build Coastguard Worker };
339*bb86c7edSAndroid Build Coastguard Worker #endif  // (RUY_PLATFORM_NEON_64&& RUY_OPT(ASM)
340*bb86c7edSAndroid Build Coastguard Worker 
341*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
// Float column-major pack routines for 64-bit NEON, implemented out of line.
// Their call sites are not visible in this excerpt.
// NOTE(review): the kNeon float PackImpl later in this header sets up four
// source column pointers with increments of 16, or 0 when a column is
// replaced by a zero-filled buffer. These presumably map onto
// src_ptr0..3 / src_inc0..3. The A55ish variant is presumably chosen for
// Tuning::kA55ish, as in the 8-bit path. Confirm against the call site and
// pack_arm.cc.
void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
                              const float* src_ptr2, const float* src_ptr3,
                              int src_inc0, int src_inc1, int src_inc2,
                              int src_inc3, int src_rows, float* packed_ptr);
void PackFloatColMajorForNeonA55ish(const float* src_ptr0,
                                    const float* src_ptr1,
                                    const float* src_ptr2,
                                    const float* src_ptr3, int src_inc0,
                                    int src_inc1, int src_inc2, int src_inc3,
                                    int src_rows, float* packed_ptr);
352*bb86c7edSAndroid Build Coastguard Worker 
353*bb86c7edSAndroid Build Coastguard Worker #elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
// Float column-major pack routine for 32-bit NEON, implemented out of line.
// Unlike the 64-bit variant, it takes a single src_inc shared by all four
// source columns, plus an extra stride argument.
// NOTE(review): the meaning of `stride` (presumably the packed-matrix
// stride) is not established here — confirm against pack_arm.cc.
void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
                              const float* src_ptr2, const float* src_ptr3,
                              int src_inc, int src_rows, float* packed_ptr,
                              int stride);
358*bb86c7edSAndroid Build Coastguard Worker #endif  // (RUY_PLATFORM_NEON_64&& RUY_OPT(ASM)
359*bb86c7edSAndroid Build Coastguard Worker 
360*bb86c7edSAndroid Build Coastguard Worker #if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
361*bb86c7edSAndroid Build Coastguard Worker 
362*bb86c7edSAndroid Build Coastguard Worker template <>
363*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
364*bb86c7edSAndroid Build Coastguard Worker                 float, float, Order::kColMajor> {
365*bb86c7edSAndroid Build Coastguard Worker   static void Run(Tuning tuning, const Mat<float>& src_matrix,
366*bb86c7edSAndroid Build Coastguard Worker                   PMat<float>* packed_matrix, int start_col, int end_col) {
367*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(src_matrix.layout));
368*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(packed_matrix->layout));
369*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ(start_col % 8, 0);
370*bb86c7edSAndroid Build Coastguard Worker     const float zerobuf[4] = {0};
371*bb86c7edSAndroid Build Coastguard Worker     for (int block_col = start_col; block_col < end_col; block_col += 4) {
372*bb86c7edSAndroid Build Coastguard Worker       int src_stride = src_matrix.layout.stride;
373*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
374*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr1 = src_ptr0 + src_stride;
375*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr2 = src_ptr1 + src_stride;
376*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr3 = src_ptr2 + src_stride;
377*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc0 = 16;
378*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc1 = 16;
379*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc2 = 16;
380*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc3 = 16;
381*bb86c7edSAndroid Build Coastguard Worker       if (block_col >= src_matrix.layout.cols - 3) {
382*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 0) {
383*bb86c7edSAndroid Build Coastguard Worker           src_ptr0 = zerobuf;
384*bb86c7edSAndroid Build Coastguard Worker           src_inc0 = 0;
385*bb86c7edSAndroid Build Coastguard Worker         }
386*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 1) {
387*bb86c7edSAndroid Build Coastguard Worker           src_ptr1 = zerobuf;
388*bb86c7edSAndroid Build Coastguard Worker           src_inc1 = 0;
389*bb86c7edSAndroid Build Coastguard Worker         }
390*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 2) {
391*bb86c7edSAndroid Build Coastguard Worker           src_ptr2 = zerobuf;
392*bb86c7edSAndroid Build Coastguard Worker           src_inc2 = 0;
393*bb86c7edSAndroid Build Coastguard Worker         }
394*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 3) {
395*bb86c7edSAndroid Build Coastguard Worker           src_ptr3 = zerobuf;
396*bb86c7edSAndroid Build Coastguard Worker           src_inc3 = 0;
397*bb86c7edSAndroid Build Coastguard Worker         }
398*bb86c7edSAndroid Build Coastguard Worker       }
399*bb86c7edSAndroid Build Coastguard Worker       float* packed_ptr = packed_matrix->data +
400*bb86c7edSAndroid Build Coastguard Worker                           packed_matrix->layout.stride * (block_col & ~7) +
401*bb86c7edSAndroid Build Coastguard Worker                           ((block_col & 4));
402*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64
403*bb86c7edSAndroid Build Coastguard Worker       if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
404*bb86c7edSAndroid Build Coastguard Worker         PackFloatColMajorForNeonA55ish(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
405*bb86c7edSAndroid Build Coastguard Worker                                        src_inc0, src_inc1, src_inc2, src_inc3,
406*bb86c7edSAndroid Build Coastguard Worker                                        src_matrix.layout.rows, packed_ptr);
407*bb86c7edSAndroid Build Coastguard Worker       } else {
408*bb86c7edSAndroid Build Coastguard Worker         PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
409*bb86c7edSAndroid Build Coastguard Worker                                  src_inc0, src_inc1, src_inc2, src_inc3,
410*bb86c7edSAndroid Build Coastguard Worker                                  src_matrix.layout.rows, packed_ptr);
411*bb86c7edSAndroid Build Coastguard Worker       }
412*bb86c7edSAndroid Build Coastguard Worker #else
413*bb86c7edSAndroid Build Coastguard Worker       (void)tuning;
414*bb86c7edSAndroid Build Coastguard Worker       // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc
415*bb86c7edSAndroid Build Coastguard Worker       // to save on registers (we have fewer general purpose registers in
416*bb86c7edSAndroid Build Coastguard Worker       // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four
417*bb86c7edSAndroid Build Coastguard Worker       // values that are each either 16 or 0 and use them directly. For the
418*bb86c7edSAndroid Build Coastguard Worker       // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should
419*bb86c7edSAndroid Build Coastguard Worker       // use the value 16 (bit is set) or 0 (bit is not set) for the
420*bb86c7edSAndroid Build Coastguard Worker       // respective increment value.
421*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc = 0;
422*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc0 == 16 ? 1 : 0;
423*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc1 == 16 ? 2 : 0;
424*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc2 == 16 ? 4 : 0;
425*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc3 == 16 ? 8 : 0;
426*bb86c7edSAndroid Build Coastguard Worker       const int kOutputStride = 32;
427*bb86c7edSAndroid Build Coastguard Worker       PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
428*bb86c7edSAndroid Build Coastguard Worker                                src_matrix.layout.rows, packed_ptr,
429*bb86c7edSAndroid Build Coastguard Worker                                kOutputStride);
430*bb86c7edSAndroid Build Coastguard Worker #endif  // RUY_PLATFORM_NEON_64
431*bb86c7edSAndroid Build Coastguard Worker     }
432*bb86c7edSAndroid Build Coastguard Worker   }
433*bb86c7edSAndroid Build Coastguard Worker };
434*bb86c7edSAndroid Build Coastguard Worker 
435*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_32
436*bb86c7edSAndroid Build Coastguard Worker // The 32-bit float kernel is 8 rows X 4 columns, so we need an additional
437*bb86c7edSAndroid Build Coastguard Worker // specialization for a FixedKernelLayout with 4 columns.
438*bb86c7edSAndroid Build Coastguard Worker template <>
439*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
440*bb86c7edSAndroid Build Coastguard Worker                 float, float, Order::kColMajor> {
441*bb86c7edSAndroid Build Coastguard Worker   static void Run(Tuning, const Mat<float>& src_matrix,
442*bb86c7edSAndroid Build Coastguard Worker                   PMat<float>* packed_matrix, int start_col, int end_col) {
443*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(src_matrix.layout));
444*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(packed_matrix->layout));
445*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ(start_col % 4, 0);
446*bb86c7edSAndroid Build Coastguard Worker     const float zerobuf[4] = {0};
447*bb86c7edSAndroid Build Coastguard Worker     for (int block_col = start_col; block_col < end_col; block_col += 4) {
448*bb86c7edSAndroid Build Coastguard Worker       int src_stride = src_matrix.layout.stride;
449*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
450*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr1 = src_ptr0 + src_stride;
451*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr2 = src_ptr1 + src_stride;
452*bb86c7edSAndroid Build Coastguard Worker       const float* src_ptr3 = src_ptr2 + src_stride;
453*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc0 = 16;
454*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc1 = 16;
455*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc2 = 16;
456*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc3 = 16;
457*bb86c7edSAndroid Build Coastguard Worker       if (block_col >= src_matrix.layout.cols - 3) {
458*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 0) {
459*bb86c7edSAndroid Build Coastguard Worker           src_ptr0 = zerobuf;
460*bb86c7edSAndroid Build Coastguard Worker           src_inc0 = 0;
461*bb86c7edSAndroid Build Coastguard Worker         }
462*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 1) {
463*bb86c7edSAndroid Build Coastguard Worker           src_ptr1 = zerobuf;
464*bb86c7edSAndroid Build Coastguard Worker           src_inc1 = 0;
465*bb86c7edSAndroid Build Coastguard Worker         }
466*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 2) {
467*bb86c7edSAndroid Build Coastguard Worker           src_ptr2 = zerobuf;
468*bb86c7edSAndroid Build Coastguard Worker           src_inc2 = 0;
469*bb86c7edSAndroid Build Coastguard Worker         }
470*bb86c7edSAndroid Build Coastguard Worker         if (block_col >= src_matrix.layout.cols - 3) {
471*bb86c7edSAndroid Build Coastguard Worker           src_ptr3 = zerobuf;
472*bb86c7edSAndroid Build Coastguard Worker           src_inc3 = 0;
473*bb86c7edSAndroid Build Coastguard Worker         }
474*bb86c7edSAndroid Build Coastguard Worker       }
475*bb86c7edSAndroid Build Coastguard Worker       float* packed_ptr =
476*bb86c7edSAndroid Build Coastguard Worker           packed_matrix->data + packed_matrix->layout.stride * (block_col);
477*bb86c7edSAndroid Build Coastguard Worker       // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc
478*bb86c7edSAndroid Build Coastguard Worker       // to save registers.
479*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc = 0;
480*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc0 == 16 ? 1 : 0;
481*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc1 == 16 ? 2 : 0;
482*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc2 == 16 ? 4 : 0;
483*bb86c7edSAndroid Build Coastguard Worker       src_inc += src_inc3 == 16 ? 8 : 0;
484*bb86c7edSAndroid Build Coastguard Worker       const int kOutputStride = 16;
485*bb86c7edSAndroid Build Coastguard Worker       PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
486*bb86c7edSAndroid Build Coastguard Worker                                src_matrix.layout.rows, packed_ptr,
487*bb86c7edSAndroid Build Coastguard Worker                                kOutputStride);
488*bb86c7edSAndroid Build Coastguard Worker     }
489*bb86c7edSAndroid Build Coastguard Worker   }
490*bb86c7edSAndroid Build Coastguard Worker };
491*bb86c7edSAndroid Build Coastguard Worker #endif  // (RUY_PLATFORM_NEON_32)
492*bb86c7edSAndroid Build Coastguard Worker #endif  // (RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && \
493*bb86c7edSAndroid Build Coastguard Worker         // RUY_OPT(ASM)
494*bb86c7edSAndroid Build Coastguard Worker 
495*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
496*bb86c7edSAndroid Build Coastguard Worker 
497*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar>
498*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
499*bb86c7edSAndroid Build Coastguard Worker                 Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
500*bb86c7edSAndroid Build Coastguard Worker   static_assert(std::is_same<Scalar, std::int8_t>::value ||
501*bb86c7edSAndroid Build Coastguard Worker                     std::is_same<Scalar, std::uint8_t>::value,
502*bb86c7edSAndroid Build Coastguard Worker                 "");
503*bb86c7edSAndroid Build Coastguard Worker   static constexpr int kInputXor =
504*bb86c7edSAndroid Build Coastguard Worker       std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
505*bb86c7edSAndroid Build Coastguard Worker 
506*bb86c7edSAndroid Build Coastguard Worker   static void Run(Tuning, const Mat<Scalar>& src_matrix,
507*bb86c7edSAndroid Build Coastguard Worker                   PMat<std::int8_t>* packed_matrix, int start_col,
508*bb86c7edSAndroid Build Coastguard Worker                   int end_col) {
509*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsRowMajor(src_matrix.layout));
510*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK(IsColMajor(packed_matrix->layout));
511*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ(start_col % 8, 0);
512*bb86c7edSAndroid Build Coastguard Worker     std::int32_t* sums = packed_matrix->sums;
513*bb86c7edSAndroid Build Coastguard Worker     std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
514*bb86c7edSAndroid Build Coastguard Worker     Scalar zerobuf[8];
515*bb86c7edSAndroid Build Coastguard Worker     memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
516*bb86c7edSAndroid Build Coastguard Worker     int src_stride = src_matrix.layout.stride;
517*bb86c7edSAndroid Build Coastguard Worker     // As the source matrix is row-major and the destination packed matrix is
518*bb86c7edSAndroid Build Coastguard Worker     // column-major, there is no traversal order that will be optimal for both
519*bb86c7edSAndroid Build Coastguard Worker     // so we choose to favor the source matrix with a row-major traversal order.
520*bb86c7edSAndroid Build Coastguard Worker     // Loop over groups of 4 rows.
521*bb86c7edSAndroid Build Coastguard Worker     for (int block_row = 0; block_row < packed_matrix->layout.rows;
522*bb86c7edSAndroid Build Coastguard Worker          block_row += 4) {
523*bb86c7edSAndroid Build Coastguard Worker       // src_ptr[0-3] shall point to the positions in the 4 rows of the source
524*bb86c7edSAndroid Build Coastguard Worker       // matrix that we are loading from, and will be incremented by
525*bb86c7edSAndroid Build Coastguard Worker       // src_inc[0-3] after each 4x8 block is loaded.
526*bb86c7edSAndroid Build Coastguard Worker       // First we compute these src_ptr and src_inc values for the case where
527*bb86c7edSAndroid Build Coastguard Worker       // there are 4 rows left to be loaded from in the source matrix ...
528*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr0 =
529*bb86c7edSAndroid Build Coastguard Worker           src_matrix.data.get() + src_stride * block_row + start_col;
530*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr1 = src_ptr0 + src_stride;
531*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr2 = src_ptr1 + src_stride;
532*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr3 = src_ptr2 + src_stride;
533*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc0 = 8;
534*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc1 = 8;
535*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc2 = 8;
536*bb86c7edSAndroid Build Coastguard Worker       std::int64_t src_inc3 = 8;
537*bb86c7edSAndroid Build Coastguard Worker       // ... and now we adjust these values in case there are fewer than 4 rows
538*bb86c7edSAndroid Build Coastguard Worker       // left to load from in the source matrix. In that case, we set the
539*bb86c7edSAndroid Build Coastguard Worker       // corresponding src_ptr pointer to load from `zerobuf` and set src_inc
540*bb86c7edSAndroid Build Coastguard Worker       // to 0 to avoid overrunning that small buffer.
541*bb86c7edSAndroid Build Coastguard Worker       if (block_row >= src_matrix.layout.rows - 3) {
542*bb86c7edSAndroid Build Coastguard Worker         if (block_row >= src_matrix.layout.rows - 0) {
543*bb86c7edSAndroid Build Coastguard Worker           src_ptr0 = zerobuf;
544*bb86c7edSAndroid Build Coastguard Worker           src_inc0 = 0;
545*bb86c7edSAndroid Build Coastguard Worker         }
546*bb86c7edSAndroid Build Coastguard Worker         if (block_row >= src_matrix.layout.rows - 1) {
547*bb86c7edSAndroid Build Coastguard Worker           src_ptr1 = zerobuf;
548*bb86c7edSAndroid Build Coastguard Worker           src_inc1 = 0;
549*bb86c7edSAndroid Build Coastguard Worker         }
550*bb86c7edSAndroid Build Coastguard Worker         if (block_row >= src_matrix.layout.rows - 2) {
551*bb86c7edSAndroid Build Coastguard Worker           src_ptr2 = zerobuf;
552*bb86c7edSAndroid Build Coastguard Worker           src_inc2 = 0;
553*bb86c7edSAndroid Build Coastguard Worker         }
554*bb86c7edSAndroid Build Coastguard Worker         if (block_row >= src_matrix.layout.rows - 3) {
555*bb86c7edSAndroid Build Coastguard Worker           src_ptr3 = zerobuf;
556*bb86c7edSAndroid Build Coastguard Worker           src_inc3 = 0;
557*bb86c7edSAndroid Build Coastguard Worker         }
558*bb86c7edSAndroid Build Coastguard Worker       }
559*bb86c7edSAndroid Build Coastguard Worker       // Let src_cols be the number of source matrix columns to handle.
560*bb86c7edSAndroid Build Coastguard Worker       int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col;
561*bb86c7edSAndroid Build Coastguard Worker       std::int8_t* packed_ptr = packed_matrix->data +
562*bb86c7edSAndroid Build Coastguard Worker                                 packed_matrix->layout.stride * start_col +
563*bb86c7edSAndroid Build Coastguard Worker                                 8 * block_row;
564*bb86c7edSAndroid Build Coastguard Worker       std::int32_t* sums_ptr = sums + start_col;
565*bb86c7edSAndroid Build Coastguard Worker       Pack8bitRowMajorForNeonDotprod(
566*bb86c7edSAndroid Build Coastguard Worker           src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, src_inc2,
567*bb86c7edSAndroid Build Coastguard Worker           src_inc3, src_cols, src_matrix.zero_point, packed_ptr,
568*bb86c7edSAndroid Build Coastguard Worker           packed_matrix->layout.stride, sums_ptr, kInputXor);
569*bb86c7edSAndroid Build Coastguard Worker     }
570*bb86c7edSAndroid Build Coastguard Worker   }
571*bb86c7edSAndroid Build Coastguard Worker };
572*bb86c7edSAndroid Build Coastguard Worker 
573*bb86c7edSAndroid Build Coastguard Worker #endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
574*bb86c7edSAndroid Build Coastguard Worker 
575*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON
576*bb86c7edSAndroid Build Coastguard Worker 
577*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar, int KernelCols>
578*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon,
579*bb86c7edSAndroid Build Coastguard Worker                 FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar,
580*bb86c7edSAndroid Build Coastguard Worker                 std::int8_t, std::int32_t, Order::kRowMajor> {
581*bb86c7edSAndroid Build Coastguard Worker   static void Run(Tuning, const Mat<Scalar>& src_matrix,
582*bb86c7edSAndroid Build Coastguard Worker                   PMat<std::int8_t>* packed_matrix, int start_col,
583*bb86c7edSAndroid Build Coastguard Worker                   int end_col) {
584*bb86c7edSAndroid Build Coastguard Worker     profiler::ScopeLabel label("Pack (KNeon, from row-major source)");
585*bb86c7edSAndroid Build Coastguard Worker     static constexpr int kInputXor =
586*bb86c7edSAndroid Build Coastguard Worker         std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
587*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
588*bb86c7edSAndroid Build Coastguard Worker     RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0);
589*bb86c7edSAndroid Build Coastguard Worker     std::int32_t* sums = packed_matrix->sums;
590*bb86c7edSAndroid Build Coastguard Worker     std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
591*bb86c7edSAndroid Build Coastguard Worker     int block_row = 0;
592*bb86c7edSAndroid Build Coastguard Worker     for (; block_row < packed_matrix->layout.rows; block_row += 16) {
593*bb86c7edSAndroid Build Coastguard Worker       int src_stride = src_matrix.layout.stride;
594*bb86c7edSAndroid Build Coastguard Worker       int packed_stride = packed_matrix->layout.stride;
595*bb86c7edSAndroid Build Coastguard Worker       const Scalar* src_ptr =
596*bb86c7edSAndroid Build Coastguard Worker           src_matrix.data.get() + block_row * src_stride + start_col;
597*bb86c7edSAndroid Build Coastguard Worker       std::int8_t* packed_ptr = packed_matrix->data +
598*bb86c7edSAndroid Build Coastguard Worker                                 start_col * packed_stride +
599*bb86c7edSAndroid Build Coastguard Worker                                 block_row * KernelCols;
600*bb86c7edSAndroid Build Coastguard Worker 
601*bb86c7edSAndroid Build Coastguard Worker       Pack8bitRowMajorForNeon(
602*bb86c7edSAndroid Build Coastguard Worker           reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride,
603*bb86c7edSAndroid Build Coastguard Worker           src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col,
604*bb86c7edSAndroid Build Coastguard Worker           end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums,
605*bb86c7edSAndroid Build Coastguard Worker           kInputXor, KernelCols);
606*bb86c7edSAndroid Build Coastguard Worker     }
607*bb86c7edSAndroid Build Coastguard Worker   }
608*bb86c7edSAndroid Build Coastguard Worker };
609*bb86c7edSAndroid Build Coastguard Worker #endif
610*bb86c7edSAndroid Build Coastguard Worker 
611*bb86c7edSAndroid Build Coastguard Worker }  // namespace ruy
612*bb86c7edSAndroid Build Coastguard Worker 
613*bb86c7edSAndroid Build Coastguard Worker #endif  // RUY_RUY_PACK_ARM_H_
614