1*bb86c7edSAndroid Build Coastguard Worker /* Copyright 2019 Google LLC. All Rights Reserved. 2*bb86c7edSAndroid Build Coastguard Worker 3*bb86c7edSAndroid Build Coastguard Worker Licensed under the Apache License, Version 2.0 (the "License"); 4*bb86c7edSAndroid Build Coastguard Worker you may not use this file except in compliance with the License. 5*bb86c7edSAndroid Build Coastguard Worker You may obtain a copy of the License at 6*bb86c7edSAndroid Build Coastguard Worker 7*bb86c7edSAndroid Build Coastguard Worker http://www.apache.org/licenses/LICENSE-2.0 8*bb86c7edSAndroid Build Coastguard Worker 9*bb86c7edSAndroid Build Coastguard Worker Unless required by applicable law or agreed to in writing, software 10*bb86c7edSAndroid Build Coastguard Worker distributed under the License is distributed on an "AS IS" BASIS, 11*bb86c7edSAndroid Build Coastguard Worker WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*bb86c7edSAndroid Build Coastguard Worker See the License for the specific language governing permissions and 13*bb86c7edSAndroid Build Coastguard Worker limitations under the License. 14*bb86c7edSAndroid Build Coastguard Worker ==============================================================================*/ 15*bb86c7edSAndroid Build Coastguard Worker 16*bb86c7edSAndroid Build Coastguard Worker #ifndef RUY_RUY_PACK_ARM_H_ 17*bb86c7edSAndroid Build Coastguard Worker #define RUY_RUY_PACK_ARM_H_ 18*bb86c7edSAndroid Build Coastguard Worker 19*bb86c7edSAndroid Build Coastguard Worker #include <algorithm> 20*bb86c7edSAndroid Build Coastguard Worker #include <cstdint> 21*bb86c7edSAndroid Build Coastguard Worker #include <type_traits> 22*bb86c7edSAndroid Build Coastguard Worker 23*bb86c7edSAndroid Build Coastguard Worker #include "ruy/asm_helpers.h" 24*bb86c7edSAndroid Build Coastguard Worker #include "ruy/check_macros.h" 25*bb86c7edSAndroid Build Coastguard Worker #include "ruy/mat.h" 26*bb86c7edSAndroid Build Coastguard Worker #include "ruy/opt_set.h" 27*bb86c7edSAndroid Build Coastguard Worker #include "ruy/pack_common.h" 28*bb86c7edSAndroid Build Coastguard Worker #include "ruy/path.h" 29*bb86c7edSAndroid Build Coastguard Worker #include "ruy/platform.h" 30*bb86c7edSAndroid Build Coastguard Worker #include "ruy/profiler/instrumentation.h" 31*bb86c7edSAndroid Build Coastguard Worker #include "ruy/tune.h" 32*bb86c7edSAndroid Build Coastguard Worker 33*bb86c7edSAndroid Build Coastguard Worker namespace ruy { 34*bb86c7edSAndroid Build Coastguard Worker 35*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON 36*bb86c7edSAndroid Build Coastguard Worker RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon) 37*bb86c7edSAndroid Build Coastguard Worker RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod) 38*bb86c7edSAndroid Build Coastguard Worker 39*bb86c7edSAndroid Build Coastguard Worker RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 8) 40*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_32 41*bb86c7edSAndroid Build Coastguard Worker RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 4) 42*bb86c7edSAndroid Build Coastguard Worker #endif 43*bb86c7edSAndroid Build Coastguard Worker 44*bb86c7edSAndroid Build Coastguard Worker template <> 45*bb86c7edSAndroid Build Coastguard Worker struct PackedTypeImpl<Path::kNeon, std::uint8_t> { 46*bb86c7edSAndroid Build Coastguard Worker using Type = std::int8_t; 47*bb86c7edSAndroid Build Coastguard Worker }; 48*bb86c7edSAndroid Build Coastguard Worker template <> 49*bb86c7edSAndroid Build Coastguard Worker struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> { 50*bb86c7edSAndroid Build Coastguard Worker using Type = std::int8_t; 51*bb86c7edSAndroid Build Coastguard Worker }; 52*bb86c7edSAndroid Build Coastguard Worker #endif 53*bb86c7edSAndroid Build Coastguard Worker 54*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON 55*bb86c7edSAndroid Build Coastguard Worker void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride, 56*bb86c7edSAndroid Build Coastguard Worker int src_rows, int src_cols, int block_row, 57*bb86c7edSAndroid Build Coastguard Worker int start_col, int end_col, 58*bb86c7edSAndroid Build Coastguard Worker std::int8_t* packed_ptr, int packed_stride, 59*bb86c7edSAndroid Build Coastguard Worker int packed_zero_point, std::int32_t* sums_ptr, 60*bb86c7edSAndroid Build Coastguard Worker int input_xor, int kernel_cols); 61*bb86c7edSAndroid Build Coastguard Worker #endif 62*bb86c7edSAndroid Build Coastguard Worker 63*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) 64*bb86c7edSAndroid Build Coastguard Worker 65*bb86c7edSAndroid Build Coastguard Worker void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1, 66*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr2, const void* src_ptr3, 67*bb86c7edSAndroid Build Coastguard Worker int src_inc0, int src_inc1, int src_inc2, 68*bb86c7edSAndroid Build Coastguard Worker int src_inc3, int src_rows, int src_zero_point, 69*bb86c7edSAndroid Build Coastguard Worker std::int8_t* packed_ptr, std::int32_t* sums_ptr, 70*bb86c7edSAndroid Build Coastguard Worker int input_xor); 71*bb86c7edSAndroid Build Coastguard Worker void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1, 72*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr2, const void* src_ptr3, 73*bb86c7edSAndroid Build Coastguard Worker int src_inc0, int src_inc1, int src_inc2, 74*bb86c7edSAndroid Build Coastguard Worker int src_inc3, int src_rows, 75*bb86c7edSAndroid Build Coastguard Worker int src_zero_point, std::int8_t* packed_ptr, 76*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums_ptr, int input_xor); 77*bb86c7edSAndroid Build Coastguard Worker void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, 78*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr2, const void* src_ptr3, 79*bb86c7edSAndroid Build Coastguard Worker int src_inc0, int src_inc1, int src_inc2, 80*bb86c7edSAndroid Build Coastguard Worker int src_inc3, int src_rows, 81*bb86c7edSAndroid Build Coastguard Worker int src_zero_point, std::int8_t* packed_ptr, 82*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums_ptr, int input_xor); 83*bb86c7edSAndroid Build Coastguard Worker void Pack8bitColMajorForNeonDotprodA55ish( 84*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr0, const void* src_ptr1, const void* src_ptr2, 85*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2, 86*bb86c7edSAndroid Build Coastguard Worker int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr, 87*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums_ptr, int input_xor); 88*bb86c7edSAndroid Build Coastguard Worker void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, 89*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr2, const void* src_ptr3, 90*bb86c7edSAndroid Build Coastguard Worker int src_inc0, int src_inc1, int src_inc2, 91*bb86c7edSAndroid Build Coastguard Worker int src_inc3, int src_cols, 92*bb86c7edSAndroid Build Coastguard Worker int src_zero_point, std::int8_t* packed_ptr, 93*bb86c7edSAndroid Build Coastguard Worker int packed_stride, std::int32_t* sums_ptr, 94*bb86c7edSAndroid Build Coastguard Worker int input_xor); 95*bb86c7edSAndroid Build Coastguard Worker #elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) 96*bb86c7edSAndroid Build Coastguard Worker 97*bb86c7edSAndroid Build Coastguard Worker struct PackParams8bit { 98*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr0; 99*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr1; 100*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr2; 101*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr3; 102*bb86c7edSAndroid Build Coastguard Worker const std::int32_t* sums_ptr; 103*bb86c7edSAndroid Build Coastguard Worker const std::int8_t* packed_ptr; 104*bb86c7edSAndroid Build Coastguard Worker int src_inc0; 105*bb86c7edSAndroid Build Coastguard Worker int src_inc1; 106*bb86c7edSAndroid Build Coastguard Worker int src_inc2; 107*bb86c7edSAndroid Build Coastguard Worker int src_inc3; 108*bb86c7edSAndroid Build Coastguard Worker int src_rows; 109*bb86c7edSAndroid Build Coastguard Worker int src_zero_point; 110*bb86c7edSAndroid Build Coastguard Worker int input_xor; 111*bb86c7edSAndroid Build Coastguard Worker }; 112*bb86c7edSAndroid Build Coastguard Worker 113*bb86c7edSAndroid Build Coastguard Worker inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1, 114*bb86c7edSAndroid Build Coastguard Worker const void* src_ptr2, const void* src_ptr3, 115*bb86c7edSAndroid Build Coastguard Worker const std::int32_t* sums_ptr, 116*bb86c7edSAndroid Build Coastguard Worker const std::int8_t* packed_ptr, int src_inc0, 117*bb86c7edSAndroid Build Coastguard Worker int src_inc1, int src_inc2, int src_inc3, 118*bb86c7edSAndroid Build Coastguard Worker int src_rows, int src_zero_point, int input_xor, 119*bb86c7edSAndroid Build Coastguard Worker PackParams8bit* params) { 120*bb86c7edSAndroid Build Coastguard Worker params->src_ptr0 = src_ptr0; 121*bb86c7edSAndroid Build Coastguard Worker params->src_ptr1 = src_ptr1; 122*bb86c7edSAndroid Build Coastguard Worker params->src_ptr2 = src_ptr2; 123*bb86c7edSAndroid Build Coastguard Worker params->src_ptr3 = src_ptr3; 124*bb86c7edSAndroid Build Coastguard Worker params->sums_ptr = sums_ptr; 125*bb86c7edSAndroid Build Coastguard Worker params->packed_ptr = packed_ptr; 126*bb86c7edSAndroid Build Coastguard Worker params->src_inc0 = src_inc0; 127*bb86c7edSAndroid Build Coastguard Worker params->src_inc1 = src_inc1; 128*bb86c7edSAndroid Build Coastguard Worker params->src_inc2 = src_inc2; 129*bb86c7edSAndroid Build Coastguard Worker params->src_inc3 = src_inc3; 130*bb86c7edSAndroid Build Coastguard Worker params->src_rows = src_rows; 131*bb86c7edSAndroid Build Coastguard Worker params->src_zero_point = src_zero_point; 132*bb86c7edSAndroid Build Coastguard Worker params->input_xor = input_xor; 133*bb86c7edSAndroid Build Coastguard Worker } 134*bb86c7edSAndroid Build Coastguard Worker 135*bb86c7edSAndroid Build Coastguard Worker void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params); 136*bb86c7edSAndroid Build Coastguard Worker void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params); 137*bb86c7edSAndroid Build Coastguard Worker 138*bb86c7edSAndroid Build Coastguard Worker #endif // (RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) 139*bb86c7edSAndroid Build Coastguard Worker 140*bb86c7edSAndroid Build Coastguard Worker #if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM) 141*bb86c7edSAndroid Build Coastguard Worker 142*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar> 143*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar, 144*bb86c7edSAndroid Build Coastguard Worker std::int8_t, std::int32_t, Order::kColMajor> { 145*bb86c7edSAndroid Build Coastguard Worker static_assert(std::is_same<Scalar, std::int8_t>::value || 146*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::uint8_t>::value, 147*bb86c7edSAndroid Build Coastguard Worker ""); 148*bb86c7edSAndroid Build Coastguard Worker static constexpr int kInputXor = 149*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; 150*bb86c7edSAndroid Build Coastguard Worker 151*bb86c7edSAndroid Build Coastguard Worker static void Run(Tuning tuning, const Mat<Scalar>& src_matrix, 152*bb86c7edSAndroid Build Coastguard Worker PMat<std::int8_t>* packed_matrix, int start_col, 153*bb86c7edSAndroid Build Coastguard Worker int end_col) { 154*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(src_matrix.layout)); 155*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(packed_matrix->layout)); 156*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ(start_col % 4, 0); 157*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums = packed_matrix->sums; 158*bb86c7edSAndroid Build Coastguard Worker Scalar zerobuf[16]; 159*bb86c7edSAndroid Build Coastguard Worker memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); 160*bb86c7edSAndroid Build Coastguard Worker for (int block_col = start_col; block_col < end_col; block_col += 4) { 161*bb86c7edSAndroid Build Coastguard Worker int src_stride = src_matrix.layout.stride; 162*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; 163*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr1 = src_ptr0 + src_stride; 164*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr2 = src_ptr1 + src_stride; 165*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr3 = src_ptr2 + src_stride; 166*bb86c7edSAndroid Build Coastguard Worker int src_inc0 = 16; 167*bb86c7edSAndroid Build Coastguard Worker int src_inc1 = 16; 168*bb86c7edSAndroid Build Coastguard Worker int src_inc2 = 16; 169*bb86c7edSAndroid Build Coastguard Worker int src_inc3 = 16; 170*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 171*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 0) { 172*bb86c7edSAndroid Build Coastguard Worker src_ptr0 = zerobuf; 173*bb86c7edSAndroid Build Coastguard Worker src_inc0 = 0; 174*bb86c7edSAndroid Build Coastguard Worker } 175*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 1) { 176*bb86c7edSAndroid Build Coastguard Worker src_ptr1 = zerobuf; 177*bb86c7edSAndroid Build Coastguard Worker src_inc1 = 0; 178*bb86c7edSAndroid Build Coastguard Worker } 179*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 2) { 180*bb86c7edSAndroid Build Coastguard Worker src_ptr2 = zerobuf; 181*bb86c7edSAndroid Build Coastguard Worker src_inc2 = 0; 182*bb86c7edSAndroid Build Coastguard Worker } 183*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 184*bb86c7edSAndroid Build Coastguard Worker src_ptr3 = zerobuf; 185*bb86c7edSAndroid Build Coastguard Worker src_inc3 = 0; 186*bb86c7edSAndroid Build Coastguard Worker } 187*bb86c7edSAndroid Build Coastguard Worker } 188*bb86c7edSAndroid Build Coastguard Worker std::int8_t* packed_ptr = 189*bb86c7edSAndroid Build Coastguard Worker packed_matrix->data + packed_matrix->layout.stride * block_col; 190*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; 191*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 192*bb86c7edSAndroid Build Coastguard Worker if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 193*bb86c7edSAndroid Build Coastguard Worker Pack8bitColMajorForNeonA55ish( 194*bb86c7edSAndroid Build Coastguard Worker src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, 195*bb86c7edSAndroid Build Coastguard Worker src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, 196*bb86c7edSAndroid Build Coastguard Worker packed_ptr, sums_ptr, kInputXor); 197*bb86c7edSAndroid Build Coastguard Worker } else { 198*bb86c7edSAndroid Build Coastguard Worker Pack8bitColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, 199*bb86c7edSAndroid Build Coastguard Worker src_inc0, src_inc1, src_inc2, src_inc3, 200*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, src_matrix.zero_point, 201*bb86c7edSAndroid Build Coastguard Worker packed_ptr, sums_ptr, kInputXor); 202*bb86c7edSAndroid Build Coastguard Worker } 203*bb86c7edSAndroid Build Coastguard Worker #else 204*bb86c7edSAndroid Build Coastguard Worker (void)tuning; 205*bb86c7edSAndroid Build Coastguard Worker // We have a more limited set of general purpose registers in ARMv7, so 206*bb86c7edSAndroid Build Coastguard Worker // we use the "params" struct technique from the kernel code to save 207*bb86c7edSAndroid Build Coastguard Worker // registers. 208*bb86c7edSAndroid Build Coastguard Worker PackParams8bit params; 209*bb86c7edSAndroid Build Coastguard Worker MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr, 210*bb86c7edSAndroid Build Coastguard Worker packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3, 211*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, src_matrix.zero_point, 212*bb86c7edSAndroid Build Coastguard Worker kInputXor, ¶ms); 213*bb86c7edSAndroid Build Coastguard Worker Pack8bitColMajorForNeon4Cols(params); 214*bb86c7edSAndroid Build Coastguard Worker #endif // RUY_PLATFORM_NEON_64 215*bb86c7edSAndroid Build Coastguard Worker } 216*bb86c7edSAndroid Build Coastguard Worker } 217*bb86c7edSAndroid Build Coastguard Worker }; 218*bb86c7edSAndroid Build Coastguard Worker 219*bb86c7edSAndroid Build Coastguard Worker #endif // (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && 220*bb86c7edSAndroid Build Coastguard Worker // RUY_OPT(ASM) 221*bb86c7edSAndroid Build Coastguard Worker 222*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) 223*bb86c7edSAndroid Build Coastguard Worker // The 32-bit float kernel is 4 rows X 2 columns, so we need an additional 224*bb86c7edSAndroid Build Coastguard Worker // partial specialization for the RHS, which has a FixedKernelLayout with 2 225*bb86c7edSAndroid Build Coastguard Worker // columns. 226*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar> 227*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar, 228*bb86c7edSAndroid Build Coastguard Worker std::int8_t, std::int32_t, Order::kColMajor> { 229*bb86c7edSAndroid Build Coastguard Worker static_assert(std::is_same<Scalar, std::int8_t>::value || 230*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::uint8_t>::value, 231*bb86c7edSAndroid Build Coastguard Worker ""); 232*bb86c7edSAndroid Build Coastguard Worker static constexpr int kInputXor = 233*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; 234*bb86c7edSAndroid Build Coastguard Worker static void Run(Tuning, const Mat<Scalar>& src_matrix, 235*bb86c7edSAndroid Build Coastguard Worker PMat<std::int8_t>* packed_matrix, int start_col, 236*bb86c7edSAndroid Build Coastguard Worker int end_col) { 237*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(src_matrix.layout)); 238*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(packed_matrix->layout)); 239*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ(start_col % 2, 0); 240*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums = packed_matrix->sums; 241*bb86c7edSAndroid Build Coastguard Worker Scalar zerobuf[16]; 242*bb86c7edSAndroid Build Coastguard Worker memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); 243*bb86c7edSAndroid Build Coastguard Worker for (int block_col = start_col; block_col < end_col; block_col += 2) { 244*bb86c7edSAndroid Build Coastguard Worker int src_stride = src_matrix.layout.stride; 245*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; 246*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr1 = src_ptr0 + src_stride; 247*bb86c7edSAndroid Build Coastguard Worker int src_inc0 = 16; 248*bb86c7edSAndroid Build Coastguard Worker int src_inc1 = 16; 249*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 2) { 250*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 0) { 251*bb86c7edSAndroid Build Coastguard Worker src_ptr0 = zerobuf; 252*bb86c7edSAndroid Build Coastguard Worker src_inc0 = 0; 253*bb86c7edSAndroid Build Coastguard Worker } 254*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 1) { 255*bb86c7edSAndroid Build Coastguard Worker src_ptr1 = zerobuf; 256*bb86c7edSAndroid Build Coastguard Worker src_inc1 = 0; 257*bb86c7edSAndroid Build Coastguard Worker } 258*bb86c7edSAndroid Build Coastguard Worker } 259*bb86c7edSAndroid Build Coastguard Worker std::int8_t* packed_ptr = 260*bb86c7edSAndroid Build Coastguard Worker packed_matrix->data + packed_matrix->layout.stride * block_col; 261*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; 262*bb86c7edSAndroid Build Coastguard Worker PackParams8bit params; 263*bb86c7edSAndroid Build Coastguard Worker MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr, 264*bb86c7edSAndroid Build Coastguard Worker packed_ptr, src_inc0, src_inc1, -1, -1, 265*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, src_matrix.zero_point, 266*bb86c7edSAndroid Build Coastguard Worker kInputXor, ¶ms); 267*bb86c7edSAndroid Build Coastguard Worker Pack8bitColMajorForNeon2Cols(params); 268*bb86c7edSAndroid Build Coastguard Worker } 269*bb86c7edSAndroid Build Coastguard Worker } 270*bb86c7edSAndroid Build Coastguard Worker }; 271*bb86c7edSAndroid Build Coastguard Worker #endif // (RUY_PLATFORM_NEON_32) && RUY_OPT(ASM) 272*bb86c7edSAndroid Build Coastguard Worker 273*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) 274*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar> 275*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>, 276*bb86c7edSAndroid Build Coastguard Worker Scalar, std::int8_t, std::int32_t, Order::kColMajor> { 277*bb86c7edSAndroid Build Coastguard Worker static_assert(std::is_same<Scalar, std::int8_t>::value || 278*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::uint8_t>::value, 279*bb86c7edSAndroid Build Coastguard Worker ""); 280*bb86c7edSAndroid Build Coastguard Worker static constexpr int kInputXor = 281*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; 282*bb86c7edSAndroid Build Coastguard Worker 283*bb86c7edSAndroid Build Coastguard Worker static void Run(Tuning tuning, const Mat<Scalar>& src_matrix, 284*bb86c7edSAndroid Build Coastguard Worker PMat<std::int8_t>* packed_matrix, int start_col, 285*bb86c7edSAndroid Build Coastguard Worker int end_col) { 286*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(src_matrix.layout)); 287*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(packed_matrix->layout)); 288*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ(start_col % 8, 0); 289*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums = packed_matrix->sums; 290*bb86c7edSAndroid Build Coastguard Worker Scalar zerobuf[16]; 291*bb86c7edSAndroid Build Coastguard Worker memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); 292*bb86c7edSAndroid Build Coastguard Worker for (int block_col = start_col; block_col < end_col; block_col += 4) { 293*bb86c7edSAndroid Build Coastguard Worker int src_stride = src_matrix.layout.stride; 294*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; 295*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr1 = src_ptr0 + src_stride; 296*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr2 = src_ptr1 + src_stride; 297*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr3 = src_ptr2 + src_stride; 298*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc0 = 16; 299*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc1 = 16; 300*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc2 = 16; 301*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc3 = 16; 302*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 303*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 0) { 304*bb86c7edSAndroid Build Coastguard Worker src_ptr0 = zerobuf; 305*bb86c7edSAndroid Build Coastguard Worker src_inc0 = 0; 306*bb86c7edSAndroid Build Coastguard Worker } 307*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 1) { 308*bb86c7edSAndroid Build Coastguard Worker src_ptr1 = zerobuf; 309*bb86c7edSAndroid Build Coastguard Worker src_inc1 = 0; 310*bb86c7edSAndroid Build Coastguard Worker } 311*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 2) { 312*bb86c7edSAndroid Build Coastguard Worker src_ptr2 = zerobuf; 313*bb86c7edSAndroid Build Coastguard Worker src_inc2 = 0; 314*bb86c7edSAndroid Build Coastguard Worker } 315*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 316*bb86c7edSAndroid Build Coastguard Worker src_ptr3 = zerobuf; 317*bb86c7edSAndroid Build Coastguard Worker src_inc3 = 0; 318*bb86c7edSAndroid Build Coastguard Worker } 319*bb86c7edSAndroid Build Coastguard Worker } 320*bb86c7edSAndroid Build Coastguard Worker std::int8_t* packed_ptr = 321*bb86c7edSAndroid Build Coastguard Worker packed_matrix->data + 322*bb86c7edSAndroid Build Coastguard Worker packed_matrix->layout.stride * (block_col & ~7) + 323*bb86c7edSAndroid Build Coastguard Worker ((block_col & 4) * 4); 324*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; 325*bb86c7edSAndroid Build Coastguard Worker if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 326*bb86c7edSAndroid Build Coastguard Worker Pack8bitColMajorForNeonDotprodA55ish( 327*bb86c7edSAndroid Build Coastguard Worker src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, 328*bb86c7edSAndroid Build Coastguard Worker src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, 329*bb86c7edSAndroid Build Coastguard Worker packed_ptr, sums_ptr, kInputXor); 330*bb86c7edSAndroid Build Coastguard Worker } else { 331*bb86c7edSAndroid Build Coastguard Worker Pack8bitColMajorForNeonDotprod( 332*bb86c7edSAndroid Build Coastguard Worker src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, 333*bb86c7edSAndroid Build Coastguard Worker src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, 334*bb86c7edSAndroid Build Coastguard Worker packed_ptr, sums_ptr, kInputXor); 335*bb86c7edSAndroid Build Coastguard Worker } 336*bb86c7edSAndroid Build Coastguard Worker } 337*bb86c7edSAndroid Build Coastguard Worker } 338*bb86c7edSAndroid Build Coastguard Worker }; 339*bb86c7edSAndroid Build Coastguard Worker #endif // (RUY_PLATFORM_NEON_64&& RUY_OPT(ASM) 340*bb86c7edSAndroid Build Coastguard Worker 341*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) 342*bb86c7edSAndroid Build Coastguard Worker void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, 343*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr2, const float* src_ptr3, 344*bb86c7edSAndroid Build Coastguard Worker int src_inc0, int src_inc1, int src_inc2, 345*bb86c7edSAndroid Build Coastguard Worker int src_inc3, int src_rows, float* packed_ptr); 346*bb86c7edSAndroid Build Coastguard Worker void PackFloatColMajorForNeonA55ish(const float* src_ptr0, 347*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr1, 348*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr2, 349*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr3, int src_inc0, 350*bb86c7edSAndroid Build Coastguard Worker int src_inc1, int src_inc2, int src_inc3, 351*bb86c7edSAndroid Build Coastguard Worker int src_rows, float* packed_ptr); 352*bb86c7edSAndroid Build Coastguard Worker 353*bb86c7edSAndroid Build Coastguard Worker #elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) 354*bb86c7edSAndroid Build Coastguard Worker void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, 355*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr2, const float* src_ptr3, 356*bb86c7edSAndroid Build Coastguard Worker int src_inc, int src_rows, float* packed_ptr, 357*bb86c7edSAndroid Build Coastguard Worker int stride); 358*bb86c7edSAndroid Build Coastguard Worker #endif // (RUY_PLATFORM_NEON_64&& RUY_OPT(ASM) 359*bb86c7edSAndroid Build Coastguard Worker 360*bb86c7edSAndroid Build Coastguard Worker #if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM) 361*bb86c7edSAndroid Build Coastguard Worker 362*bb86c7edSAndroid Build Coastguard Worker template <> 363*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, 364*bb86c7edSAndroid Build Coastguard Worker float, float, Order::kColMajor> { 365*bb86c7edSAndroid Build Coastguard Worker static void Run(Tuning tuning, const Mat<float>& src_matrix, 366*bb86c7edSAndroid Build Coastguard Worker PMat<float>* packed_matrix, int start_col, int end_col) { 367*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(src_matrix.layout)); 368*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(packed_matrix->layout)); 369*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ(start_col % 8, 0); 370*bb86c7edSAndroid Build Coastguard Worker const float zerobuf[4] = {0}; 371*bb86c7edSAndroid Build Coastguard Worker for (int block_col = start_col; block_col < end_col; block_col += 4) { 372*bb86c7edSAndroid Build Coastguard Worker int src_stride = src_matrix.layout.stride; 373*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; 374*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr1 = src_ptr0 + src_stride; 375*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr2 = src_ptr1 + src_stride; 376*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr3 = src_ptr2 + src_stride; 377*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc0 = 16; 378*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc1 = 16; 379*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc2 = 16; 380*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc3 = 16; 381*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 382*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 0) { 383*bb86c7edSAndroid Build Coastguard Worker src_ptr0 = zerobuf; 384*bb86c7edSAndroid Build Coastguard Worker src_inc0 = 0; 385*bb86c7edSAndroid Build Coastguard Worker } 386*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 1) { 387*bb86c7edSAndroid Build Coastguard Worker src_ptr1 = zerobuf; 388*bb86c7edSAndroid Build Coastguard Worker src_inc1 = 0; 389*bb86c7edSAndroid Build Coastguard Worker } 390*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 2) { 391*bb86c7edSAndroid Build Coastguard Worker src_ptr2 = zerobuf; 392*bb86c7edSAndroid Build Coastguard Worker src_inc2 = 0; 393*bb86c7edSAndroid Build Coastguard Worker } 394*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 395*bb86c7edSAndroid Build Coastguard Worker src_ptr3 = zerobuf; 396*bb86c7edSAndroid Build Coastguard Worker src_inc3 = 0; 397*bb86c7edSAndroid Build Coastguard Worker } 398*bb86c7edSAndroid Build Coastguard Worker } 399*bb86c7edSAndroid Build Coastguard Worker float* packed_ptr = packed_matrix->data + 400*bb86c7edSAndroid Build Coastguard Worker packed_matrix->layout.stride * (block_col & ~7) + 401*bb86c7edSAndroid Build Coastguard Worker ((block_col & 4)); 402*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 403*bb86c7edSAndroid Build Coastguard Worker if (__builtin_expect(tuning == Tuning::kA55ish, true)) { 404*bb86c7edSAndroid Build Coastguard Worker PackFloatColMajorForNeonA55ish(src_ptr0, src_ptr1, src_ptr2, src_ptr3, 405*bb86c7edSAndroid Build Coastguard Worker src_inc0, src_inc1, src_inc2, src_inc3, 406*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, packed_ptr); 407*bb86c7edSAndroid Build Coastguard Worker } else { 408*bb86c7edSAndroid Build Coastguard Worker PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, 409*bb86c7edSAndroid Build Coastguard Worker src_inc0, src_inc1, src_inc2, src_inc3, 410*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, packed_ptr); 411*bb86c7edSAndroid Build Coastguard Worker } 412*bb86c7edSAndroid Build Coastguard Worker #else 413*bb86c7edSAndroid Build Coastguard Worker (void)tuning; 414*bb86c7edSAndroid Build Coastguard Worker // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc 415*bb86c7edSAndroid Build Coastguard Worker // to save on registers (we have fewer general purpose registers in 416*bb86c7edSAndroid Build Coastguard Worker // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four 417*bb86c7edSAndroid Build Coastguard Worker // values that are each either 16 or 0 and use them directly. For the 418*bb86c7edSAndroid Build Coastguard Worker // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should 419*bb86c7edSAndroid Build Coastguard Worker // use the value 16 (bit is set) or 0 (bit is not set) for the 420*bb86c7edSAndroid Build Coastguard Worker // respective increment value. 421*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc = 0; 422*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc0 == 16 ? 1 : 0; 423*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc1 == 16 ? 2 : 0; 424*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc2 == 16 ? 4 : 0; 425*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc3 == 16 ? 8 : 0; 426*bb86c7edSAndroid Build Coastguard Worker const int kOutputStride = 32; 427*bb86c7edSAndroid Build Coastguard Worker PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, 428*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, packed_ptr, 429*bb86c7edSAndroid Build Coastguard Worker kOutputStride); 430*bb86c7edSAndroid Build Coastguard Worker #endif // RUY_PLATFORM_NEON_64 431*bb86c7edSAndroid Build Coastguard Worker } 432*bb86c7edSAndroid Build Coastguard Worker } 433*bb86c7edSAndroid Build Coastguard Worker }; 434*bb86c7edSAndroid Build Coastguard Worker 435*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_32 436*bb86c7edSAndroid Build Coastguard Worker // The 32-bit float kernel is 8 rows X 4 columns, so we need an additional 437*bb86c7edSAndroid Build Coastguard Worker // specialization for a FixedKernelLayout with 4 columns. 438*bb86c7edSAndroid Build Coastguard Worker template <> 439*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float, 440*bb86c7edSAndroid Build Coastguard Worker float, float, Order::kColMajor> { 441*bb86c7edSAndroid Build Coastguard Worker static void Run(Tuning, const Mat<float>& src_matrix, 442*bb86c7edSAndroid Build Coastguard Worker PMat<float>* packed_matrix, int start_col, int end_col) { 443*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(src_matrix.layout)); 444*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(packed_matrix->layout)); 445*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ(start_col % 4, 0); 446*bb86c7edSAndroid Build Coastguard Worker const float zerobuf[4] = {0}; 447*bb86c7edSAndroid Build Coastguard Worker for (int block_col = start_col; block_col < end_col; block_col += 4) { 448*bb86c7edSAndroid Build Coastguard Worker int src_stride = src_matrix.layout.stride; 449*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; 450*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr1 = src_ptr0 + src_stride; 451*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr2 = src_ptr1 + src_stride; 452*bb86c7edSAndroid Build Coastguard Worker const float* src_ptr3 = src_ptr2 + src_stride; 453*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc0 = 16; 454*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc1 = 16; 455*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc2 = 16; 456*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc3 = 16; 457*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 458*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 0) { 459*bb86c7edSAndroid Build Coastguard Worker src_ptr0 = zerobuf; 460*bb86c7edSAndroid Build Coastguard Worker src_inc0 = 0; 461*bb86c7edSAndroid Build Coastguard Worker } 462*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 1) { 463*bb86c7edSAndroid Build Coastguard Worker src_ptr1 = zerobuf; 464*bb86c7edSAndroid Build Coastguard Worker src_inc1 = 0; 465*bb86c7edSAndroid Build Coastguard Worker } 466*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 2) { 467*bb86c7edSAndroid Build Coastguard Worker src_ptr2 = zerobuf; 468*bb86c7edSAndroid Build Coastguard Worker src_inc2 = 0; 469*bb86c7edSAndroid Build Coastguard Worker } 470*bb86c7edSAndroid Build Coastguard Worker if (block_col >= src_matrix.layout.cols - 3) { 471*bb86c7edSAndroid Build Coastguard Worker src_ptr3 = zerobuf; 472*bb86c7edSAndroid Build Coastguard Worker src_inc3 = 0; 473*bb86c7edSAndroid Build Coastguard Worker } 474*bb86c7edSAndroid Build Coastguard Worker } 475*bb86c7edSAndroid Build Coastguard Worker float* packed_ptr = 476*bb86c7edSAndroid Build Coastguard Worker packed_matrix->data + packed_matrix->layout.stride * (block_col); 477*bb86c7edSAndroid Build Coastguard Worker // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc 478*bb86c7edSAndroid Build Coastguard Worker // to save registers. 479*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc = 0; 480*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc0 == 16 ? 1 : 0; 481*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc1 == 16 ? 2 : 0; 482*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc2 == 16 ? 4 : 0; 483*bb86c7edSAndroid Build Coastguard Worker src_inc += src_inc3 == 16 ? 8 : 0; 484*bb86c7edSAndroid Build Coastguard Worker const int kOutputStride = 16; 485*bb86c7edSAndroid Build Coastguard Worker PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, 486*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, packed_ptr, 487*bb86c7edSAndroid Build Coastguard Worker kOutputStride); 488*bb86c7edSAndroid Build Coastguard Worker } 489*bb86c7edSAndroid Build Coastguard Worker } 490*bb86c7edSAndroid Build Coastguard Worker }; 491*bb86c7edSAndroid Build Coastguard Worker #endif // (RUY_PLATFORM_NEON_32) 492*bb86c7edSAndroid Build Coastguard Worker #endif // (RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && \ 493*bb86c7edSAndroid Build Coastguard Worker // RUY_OPT(ASM) 494*bb86c7edSAndroid Build Coastguard Worker 495*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) 496*bb86c7edSAndroid Build Coastguard Worker 497*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar> 498*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>, 499*bb86c7edSAndroid Build Coastguard Worker Scalar, std::int8_t, std::int32_t, Order::kRowMajor> { 500*bb86c7edSAndroid Build Coastguard Worker static_assert(std::is_same<Scalar, std::int8_t>::value || 501*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::uint8_t>::value, 502*bb86c7edSAndroid Build Coastguard Worker ""); 503*bb86c7edSAndroid Build Coastguard Worker static constexpr int kInputXor = 504*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; 505*bb86c7edSAndroid Build Coastguard Worker 506*bb86c7edSAndroid Build Coastguard Worker static void Run(Tuning, const Mat<Scalar>& src_matrix, 507*bb86c7edSAndroid Build Coastguard Worker PMat<std::int8_t>* packed_matrix, int start_col, 508*bb86c7edSAndroid Build Coastguard Worker int end_col) { 509*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsRowMajor(src_matrix.layout)); 510*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK(IsColMajor(packed_matrix->layout)); 511*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ(start_col % 8, 0); 512*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums = packed_matrix->sums; 513*bb86c7edSAndroid Build Coastguard Worker std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); 514*bb86c7edSAndroid Build Coastguard Worker Scalar zerobuf[8]; 515*bb86c7edSAndroid Build Coastguard Worker memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); 516*bb86c7edSAndroid Build Coastguard Worker int src_stride = src_matrix.layout.stride; 517*bb86c7edSAndroid Build Coastguard Worker // As the source matrix is row-major and the destination packed matrix is 518*bb86c7edSAndroid Build Coastguard Worker // column-major, there is no traversal order that will be optimal for both 519*bb86c7edSAndroid Build Coastguard Worker // so we choose to favor the source matrix with a row-major traversal order. 520*bb86c7edSAndroid Build Coastguard Worker // Loop over groups of 4 rows. 521*bb86c7edSAndroid Build Coastguard Worker for (int block_row = 0; block_row < packed_matrix->layout.rows; 522*bb86c7edSAndroid Build Coastguard Worker block_row += 4) { 523*bb86c7edSAndroid Build Coastguard Worker // src_ptr[0-3] shall point to the positions in the 4 rows of the source 524*bb86c7edSAndroid Build Coastguard Worker // matrix that we are loading from, and will be incremented by 525*bb86c7edSAndroid Build Coastguard Worker // src_inc[0-3] after each 4x8 block is loaded. 526*bb86c7edSAndroid Build Coastguard Worker // First we compute these src_ptr and src_inc values for the case where 527*bb86c7edSAndroid Build Coastguard Worker // there are 4 rows left to be loaded from in the source matrix ... 528*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr0 = 529*bb86c7edSAndroid Build Coastguard Worker src_matrix.data.get() + src_stride * block_row + start_col; 530*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr1 = src_ptr0 + src_stride; 531*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr2 = src_ptr1 + src_stride; 532*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr3 = src_ptr2 + src_stride; 533*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc0 = 8; 534*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc1 = 8; 535*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc2 = 8; 536*bb86c7edSAndroid Build Coastguard Worker std::int64_t src_inc3 = 8; 537*bb86c7edSAndroid Build Coastguard Worker // ... and now we adjust these values in case there are fewer than 4 rows 538*bb86c7edSAndroid Build Coastguard Worker // left to load from in the source matrix. In that case, we set the 539*bb86c7edSAndroid Build Coastguard Worker // corresponding src_ptr pointer to load from `zerobuf` and set src_inc 540*bb86c7edSAndroid Build Coastguard Worker // to 0 to avoid overrunning that small buffer. 541*bb86c7edSAndroid Build Coastguard Worker if (block_row >= src_matrix.layout.rows - 3) { 542*bb86c7edSAndroid Build Coastguard Worker if (block_row >= src_matrix.layout.rows - 0) { 543*bb86c7edSAndroid Build Coastguard Worker src_ptr0 = zerobuf; 544*bb86c7edSAndroid Build Coastguard Worker src_inc0 = 0; 545*bb86c7edSAndroid Build Coastguard Worker } 546*bb86c7edSAndroid Build Coastguard Worker if (block_row >= src_matrix.layout.rows - 1) { 547*bb86c7edSAndroid Build Coastguard Worker src_ptr1 = zerobuf; 548*bb86c7edSAndroid Build Coastguard Worker src_inc1 = 0; 549*bb86c7edSAndroid Build Coastguard Worker } 550*bb86c7edSAndroid Build Coastguard Worker if (block_row >= src_matrix.layout.rows - 2) { 551*bb86c7edSAndroid Build Coastguard Worker src_ptr2 = zerobuf; 552*bb86c7edSAndroid Build Coastguard Worker src_inc2 = 0; 553*bb86c7edSAndroid Build Coastguard Worker } 554*bb86c7edSAndroid Build Coastguard Worker if (block_row >= src_matrix.layout.rows - 3) { 555*bb86c7edSAndroid Build Coastguard Worker src_ptr3 = zerobuf; 556*bb86c7edSAndroid Build Coastguard Worker src_inc3 = 0; 557*bb86c7edSAndroid Build Coastguard Worker } 558*bb86c7edSAndroid Build Coastguard Worker } 559*bb86c7edSAndroid Build Coastguard Worker // Let src_cols be the number of source matrix columns to handle. 560*bb86c7edSAndroid Build Coastguard Worker int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col; 561*bb86c7edSAndroid Build Coastguard Worker std::int8_t* packed_ptr = packed_matrix->data + 562*bb86c7edSAndroid Build Coastguard Worker packed_matrix->layout.stride * start_col + 563*bb86c7edSAndroid Build Coastguard Worker 8 * block_row; 564*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums_ptr = sums + start_col; 565*bb86c7edSAndroid Build Coastguard Worker Pack8bitRowMajorForNeonDotprod( 566*bb86c7edSAndroid Build Coastguard Worker src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, src_inc2, 567*bb86c7edSAndroid Build Coastguard Worker src_inc3, src_cols, src_matrix.zero_point, packed_ptr, 568*bb86c7edSAndroid Build Coastguard Worker packed_matrix->layout.stride, sums_ptr, kInputXor); 569*bb86c7edSAndroid Build Coastguard Worker } 570*bb86c7edSAndroid Build Coastguard Worker } 571*bb86c7edSAndroid Build Coastguard Worker }; 572*bb86c7edSAndroid Build Coastguard Worker 573*bb86c7edSAndroid Build Coastguard Worker #endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) 574*bb86c7edSAndroid Build Coastguard Worker 575*bb86c7edSAndroid Build Coastguard Worker #if RUY_PLATFORM_NEON 576*bb86c7edSAndroid Build Coastguard Worker 577*bb86c7edSAndroid Build Coastguard Worker template <typename Scalar, int KernelCols> 578*bb86c7edSAndroid Build Coastguard Worker struct PackImpl<Path::kNeon, 579*bb86c7edSAndroid Build Coastguard Worker FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar, 580*bb86c7edSAndroid Build Coastguard Worker std::int8_t, std::int32_t, Order::kRowMajor> { 581*bb86c7edSAndroid Build Coastguard Worker static void Run(Tuning, const Mat<Scalar>& src_matrix, 582*bb86c7edSAndroid Build Coastguard Worker PMat<std::int8_t>* packed_matrix, int start_col, 583*bb86c7edSAndroid Build Coastguard Worker int end_col) { 584*bb86c7edSAndroid Build Coastguard Worker profiler::ScopeLabel label("Pack (KNeon, from row-major source)"); 585*bb86c7edSAndroid Build Coastguard Worker static constexpr int kInputXor = 586*bb86c7edSAndroid Build Coastguard Worker std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; 587*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); 588*bb86c7edSAndroid Build Coastguard Worker RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0); 589*bb86c7edSAndroid Build Coastguard Worker std::int32_t* sums = packed_matrix->sums; 590*bb86c7edSAndroid Build Coastguard Worker std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); 591*bb86c7edSAndroid Build Coastguard Worker int block_row = 0; 592*bb86c7edSAndroid Build Coastguard Worker for (; block_row < packed_matrix->layout.rows; block_row += 16) { 593*bb86c7edSAndroid Build Coastguard Worker int src_stride = src_matrix.layout.stride; 594*bb86c7edSAndroid Build Coastguard Worker int packed_stride = packed_matrix->layout.stride; 595*bb86c7edSAndroid Build Coastguard Worker const Scalar* src_ptr = 596*bb86c7edSAndroid Build Coastguard Worker src_matrix.data.get() + block_row * src_stride + start_col; 597*bb86c7edSAndroid Build Coastguard Worker std::int8_t* packed_ptr = packed_matrix->data + 598*bb86c7edSAndroid Build Coastguard Worker start_col * packed_stride + 599*bb86c7edSAndroid Build Coastguard Worker block_row * KernelCols; 600*bb86c7edSAndroid Build Coastguard Worker 601*bb86c7edSAndroid Build Coastguard Worker Pack8bitRowMajorForNeon( 602*bb86c7edSAndroid Build Coastguard Worker reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride, 603*bb86c7edSAndroid Build Coastguard Worker src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col, 604*bb86c7edSAndroid Build Coastguard Worker end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums, 605*bb86c7edSAndroid Build Coastguard Worker kInputXor, KernelCols); 606*bb86c7edSAndroid Build Coastguard Worker } 607*bb86c7edSAndroid Build Coastguard Worker } 608*bb86c7edSAndroid Build Coastguard Worker }; 609*bb86c7edSAndroid Build Coastguard Worker #endif 610*bb86c7edSAndroid Build Coastguard Worker 611*bb86c7edSAndroid Build Coastguard Worker } // namespace ruy 612*bb86c7edSAndroid Build Coastguard Worker 613*bb86c7edSAndroid Build Coastguard Worker #endif // RUY_RUY_PACK_ARM_H_ 614