// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Benoit Steiner <[email protected]>
// Copyright (C) 2015 Matthew Sarett <[email protected]>
// Copyright (C) 2016 Nishant Patil <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_
#define CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_

namespace Eigen {
namespace internal {

// AVX2 optimized implementation of Mat-Mat product.
// LHS is encoded using signed 16-bit integers.
// RHS is encoded using signed 16-bit integers.
#ifdef EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT

// Define quantized traits
template <bool _ConjLhs, bool _ConjRhs>
class gebp_traits<QInt16, QInt16, _ConjLhs, _ConjRhs> {
 public:
  typedef QInt16 LhsScalar;
  typedef QInt16 RhsScalar;
  typedef QInt32 ResScalar;

  typedef typename packet_traits<LhsScalar>::type LhsPacket;
  typedef LhsPacket LhsPacket4Packing;

  enum {
    // Define register blocking scheme.
    nr = 16,
    mr = 16,
    kr = 4,
    // Ignore progress tracking per loop iteration.
    LhsProgress = -1,
    RhsProgress = -1
  };
};

// Specialized blocking for quantized implementations.
// Used by TensorContractionThreadPool; block sizes are rounded up to
// multiples of 16.
template <typename Index, int ShardingType>
class TensorContractionBlocking<QInt16, QInt16, QInt16, Index, ShardingType> {
 public:
  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
      : kc_(((k + 15) / 16) * 16),
        mc_(((m + 15) / 16) * 16),
        nc_(((n + 15) / 16) * 16) {
    eigen_assert(mc_ % 16 == 0);
    eigen_assert(kc_ % 16 == 0);
    if (!k || !m || !n) {
      return;
    }

    if (ShardingType == ShardByCol) {
      eigen_assert(nc_ % 16 == 0);
      nc_ = (((nc_ / num_threads) + 15) / 16) * 16;
    } else {
      eigen_assert(nc_ % 16 == 0);
      mc_ = (((mc_ / num_threads) + 15) / 16) * 16;
    }
  }

  EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }

 private:
  Index kc_;
  Index mc_;
  Index nc_;
};

// Specialized blocking for quantized implementations.
// Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to
// multiples of 16.
template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<ColMajor, QInt16, QInt16, MaxRows, MaxCols, MaxDepth,
                          KcFactor, false>
    : public level3_blocking<QInt16, QInt16> {
  DenseIndex m_sizeA;
  DenseIndex m_sizeB;

 public:
  gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
                      DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
    this->m_mc = ((rows + 15) / 16) * 16;
    this->m_nc = ((cols + 15) / 16) * 16;
    this->m_kc = ((depth + 15) / 16) * 16;
    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }
  void allocateA() {
    if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt16>(m_sizeA);
  }
  void allocateB() {
    if (this->m_blockB == 0) this->m_blockB = aligned_new<QInt16>(m_sizeB);
  }
  void allocateAll() {
    allocateA();
    allocateB();
  }
  ~gemm_blocking_space() {
    aligned_delete(this->m_blockA, m_sizeA);
    aligned_delete(this->m_blockB, m_sizeB);
  }
};

// Below are the fully optimized versions that are correct only for sizes that
// are a multiple of 16. It is about a 10% performance benefit to keep these
// implementations separate.

// Arrange a block of the left input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
//
// Packing with m = 8 yields row major output (A0 beside B0 in memory):
// A0 B0
// A1 B1
// A2 B2
// A3 B3
// A4 B4
// A5 B5
// A6 B6
// A7 B7
// ...
//
// The purpose is to collect m rows of size k. Two elements of the same
// row are arranged contiguously because madd performs an adjacent addition
// in the kernel.
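//
// For reference only, a scalar sketch of the order in which the AVX2 packing
// below writes blockA (illustrative, not compiled; assumes direct element
// access lhs(m, k) and that rows and depth are multiples of 16):
//
//   int idx = 0;
//   for (int m0 = 0; m0 < rows; m0 += 16)        // panels of 16 rows
//     for (int k0 = 0; k0 < depth; k0 += 4)      // panels of 4 depth values
//       for (int half = 0; half < 16; half += 8) // rows m0..m0+7, then m0+8..m0+15
//         for (int kp = 0; kp < 4; kp += 2)      // depth pairs (k0,k0+1), (k0+2,k0+3)
//           for (int m = m0 + half; m < m0 + half + 8; ++m) {
//             blockA[idx++] = lhs(m, k0 + kp);
//             blockA[idx++] = lhs(m, k0 + kp + 1);
//           }
//
// The unpack/permute sequence in gemm_pack_lhs produces this layout without
// scalar stores.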

template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2, QInt16, ColMajor,
                     Conjugate, PanelMode> {
  EIGEN_DONT_INLINE void operator()(QInt16* blockA, const DataMapper& lhs,
                                    Index depth, Index rows, Index stride = 0,
                                    Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2,
                                     QInt16, ColMajor, Conjugate, PanelMode>::
operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows,
           Index stride, Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QInt16>::type Packet;

  // Use alternate function for weird sizes
  if (rows % 16 != 0 || depth % 16 != 0) {
    assert(false &&
           "only depths and rows that are a multiple of 16 are currently "
           "supported");
    // gemm_pack_lhs_any<QInt16, Index, DataMapper, Pack1, Pack2, ColMajor,
    //                   Conjugate, PanelMode> lhs_pack;
    // return lhs_pack(blockA, lhs, depth, rows, stride, offset);
  }

  // Get vector pointer
  __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);

  // Pack rows in sets of 16
  for (Index m = 0; m < rows; m += 16) {
    // Pack depth in sets of 4
    for (Index k = 0; k < depth; k += 4) {
      // Load vectors
      __m256i L_A = lhs.template loadPacket<Packet>(m, k);
      __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);
      __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
      __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);

      // Rearrange the inputs as required by the kernel
      __m256i L_AB0_AB7 = _mm256_unpacklo_epi16(L_A, L_B);
      __m256i L_AB8_AB15 = _mm256_unpackhi_epi16(L_A, L_B);
      __m256i L_CD0_CD7 = _mm256_unpacklo_epi16(L_C, L_D);
      __m256i L_CD8_CD15 = _mm256_unpackhi_epi16(L_C, L_D);

      __m256i L_AD0 = _mm256_permute2x128_si256(L_AB0_AB7, L_AB8_AB15, 0x20);
      _mm256_store_si256(blockA_256++, L_AD0);
      __m256i L_AD8 = _mm256_permute2x128_si256(L_CD0_CD7, L_CD8_CD15, 0x20);
      _mm256_store_si256(blockA_256++, L_AD8);
      __m256i L_AD16 = _mm256_permute2x128_si256(L_AB0_AB7, L_AB8_AB15, 0x31);
      _mm256_store_si256(blockA_256++, L_AD16);
      __m256i L_AD24 = _mm256_permute2x128_si256(L_CD0_CD7, L_CD8_CD15, 0x31);
      _mm256_store_si256(blockA_256++, L_AD24);
    }
  }
}

// Arrange a block of the right input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
// Packing yields row major output (A0 beside A1 in memory):
// A0 A1 A2 A3 A4 A5 A6 A7
// B0 B1 B2 B3 B4 B5 B6 B7
// ...
//
// At least two elements of the same col are arranged contiguously because
// maddubs and madd both perform an adjacent addition in the kernel. We can
// save work by leaving 4 adjacent elements because kr = 4.
// The purpose is to collect n cols of size k. Two elements of the same
// col are arranged contiguously because madd performs an adjacent addition
// in the kernel.
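//
// For reference only, a scalar sketch of the order in which the AVX2 packing
// below writes blockB (illustrative, not compiled; assumes direct element
// access rhs(k, n) and that cols and depth are multiples of 16):
//
//   int idx = 0;
//   for (int n0 = 0; n0 < cols; n0 += 16)     // panels of 16 columns
//     for (int k0 = 0; k0 < depth; k0 += 16)  // panels of 16 depth values
//       for (int kq = 0; kq < 16; kq += 4)    // depth quads
//         for (int nq = 0; nq < 16; nq += 4)  // column quads
//           for (int n = n0 + nq; n < n0 + nq + 4; ++n)
//             for (int k = k0 + kq; k < k0 + kq + 4; ++k)
//               blockB[idx++] = rhs(k, n);
//
// Each 256-bit store therefore holds 4 columns with 4 adjacent depth values
// of the same column kept contiguous, which is what the kernel consumes.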
template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
struct gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor, Conjugate,
                     PanelMode> {
  EIGEN_DONT_INLINE void operator()(QInt16* blockB, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor,
                                     Conjugate, PanelMode>::
operator()(QInt16* blockB, const DataMapper& rhs, Index depth, Index cols,
           Index stride, Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QInt16>::type Packet;

  // Use alternate function for weird sizes
  if (cols % 16 != 0 || depth % 16 != 0) {
    assert(false &&
           "only depths and cols that are a multiple of 16 are currently "
           "supported");
    // gemm_pack_rhs_any<QInt16, Index, DataMapper, nr, ColMajor, Conjugate,
    //                   PanelMode> rhs_pack;
    // return rhs_pack(blockB, rhs, depth, cols, stride, offset);
  }

  // Get vector pointer
  __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);

  // Perform a step of the packing for 4 columns
  __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_4, R_AD_8, R_AD_12;
#define PACK_STEP                                             \
  R_AB_L = _mm256_unpacklo_epi64(R_A, R_B);                   \
  R_CD_L = _mm256_unpacklo_epi64(R_C, R_D);                   \
  R_AB_H = _mm256_unpackhi_epi64(R_A, R_B);                   \
  R_CD_H = _mm256_unpackhi_epi64(R_C, R_D);                   \
  R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20);   \
  R_AD_8 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31);   \
  R_AD_4 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20);   \
  R_AD_12 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31);  \
  _mm256_store_si256(blockB_256, R_AD_0);                     \
  _mm256_store_si256(blockB_256 + 4, R_AD_4);                 \
  _mm256_store_si256(blockB_256 + 8, R_AD_8);                 \
  _mm256_store_si256(blockB_256 + 12, R_AD_12);               \
  blockB_256++;

  // Pack cols in sets of 16
  for (Index n = 0; n < cols; n += 16) {
    // Pack depth in sets of 16
    for (Index k = 0; k < depth; k += 16) {
      __m256i R_A = rhs.template loadPacket<Packet>(k, n);
      __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
      __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
      __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 4);
      R_B = rhs.template loadPacket<Packet>(k, n + 5);
      R_C = rhs.template loadPacket<Packet>(k, n + 6);
      R_D = rhs.template loadPacket<Packet>(k, n + 7);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 8);
      R_B = rhs.template loadPacket<Packet>(k, n + 9);
      R_C = rhs.template loadPacket<Packet>(k, n + 10);
      R_D = rhs.template loadPacket<Packet>(k, n + 11);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 12);
      R_B = rhs.template loadPacket<Packet>(k, n + 13);
      R_C = rhs.template loadPacket<Packet>(k, n + 14);
      R_D = rhs.template loadPacket<Packet>(k, n + 15);
      PACK_STEP;

      blockB_256 += 12;
    }
  }
#undef PACK_STEP
}

// Perform the actual multiplication on packed inputs
template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<QInt16, QInt16, Index, DataMapper, mr, nr, ConjugateLhs,
                   ConjugateRhs> {
  typedef typename DataMapper::LinearMapper LinearMapper;

  EIGEN_DONT_INLINE
  void operator()(const DataMapper& res, const QInt16* blockA,
                  const QInt16* blockB, Index rows, Index depth, Index cols,
                  QInt32 alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE void gebp_kernel<QInt16, QInt16, Index, DataMapper, mr, nr,
                                   ConjugateLhs, ConjugateRhs>::
operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB,
           Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
           Index strideB, Index offsetA, Index offsetB) {
  EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  eigen_assert(alpha.value == 1);
  eigen_assert(strideA == -1);
  eigen_assert(strideB == -1);
  eigen_assert(offsetA == 0);
  eigen_assert(offsetB == 0);
  eigen_assert(rows > 0);
  eigen_assert(cols > 0);
  eigen_assert(depth > 0);
  eigen_assert(blockA);
  eigen_assert(blockB);

  // Use alternate function for weird sizes
  if (rows % 16 != 0 || cols % 16 != 0 || depth % 16 != 0) {
    assert(false &&
           "only depths, cols and rows that are a multiple of 16 are currently "
           "supported");
    // gebp_kernel_any<QInt16, QInt16, Index, DataMapper, mr, nr, ConjugateLhs,
    //                 ConjugateRhs> gebp;
    // return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA,
    //             strideB, offsetA, offsetB);
  }

  // Create result block
  QInt32* blockO = aligned_new<QInt32>(16 * 16);
  memset(blockO, 0, 16 * 16 * sizeof(QInt32));

  // Get vectorized pointers
  __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
  const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
  const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);

  // Loop over blocks of 16 columns
  for (Index n = 0; n < cols; n += 16) {
    // Reset index into blockA
    Index indexL = 0;
    // Loop over blocks of 16 rows
    for (Index m = 0; m < rows; m += 16) {
      // Reset index into blockB
      Index indexR = n / 16 * depth;
      // Loop over blocks of 4 on depth
      for (Index k = 0; k < depth; k += 4) {
        // Load inputs
        __m256i L_AD0 = blockA_256[indexL++];
        __m256i L_AD8 = blockA_256[indexL++];
        __m256i L_EH0 = blockA_256[indexL++];
        __m256i L_EH8 = blockA_256[indexL++];

        __m256i R_AH0 = blockB_256[indexR++];
        __m256i R_AH4 = blockB_256[indexR++];
        __m256i R_AH8 = blockB_256[indexR++];
        __m256i R_AH12 = blockB_256[indexR++];

        // Declare variables used in COMPUTE_STEP
        __m256i P_32_A, P_32_B, P_32;

#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                          \
  P_32_A = _mm256_madd_epi16(R_INPUT_A, L_AD0);                             \
  P_32_B = _mm256_madd_epi16(R_INPUT_B, L_AD8);                             \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                  \
  _mm256_store_si256(                                                       \
      blockO_256 + 2 * OFFSET,                                              \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET), P_32));  \
                                                                            \
  P_32_A = _mm256_madd_epi16(R_INPUT_A, L_EH0);                             \
  P_32_B = _mm256_madd_epi16(R_INPUT_B, L_EH8);                             \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                  \
  _mm256_store_si256(                                                       \
      blockO_256 + 2 * OFFSET + 1,                                          \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET + 1), P_32));
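
        // One COMPUTE_STEP accumulates, for the single output column
        // c = n + OFFSET, the contribution of the four depth values k..k+3
        // to the 16 rows of the current tile. An illustrative scalar
        // equivalent (not compiled; lhs/rhs denote the unpacked matrices):
        //
        //   for (int r = 0; r < 16; ++r)
        //     for (int d = 0; d < 4; ++d)
        //       blockO[OFFSET * 16 + r] +=
        //           int32_t(lhs(m + r, k + d)) * int32_t(rhs(k + d, n + OFFSET));
        //
        // _mm256_madd_epi16 does the 16-bit multiply and the adjacent-pair
        // addition in one instruction, which is why the packing above keeps
        // pairs of depth values contiguous.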

        // Permute and shuffle to copy a single value across the entire vector
        // Then compute the multiplication
        // Replicate lower 128-bits of R_AH0 across both lanes
        __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
        // Copy first two elements of R_AH0 across entire vector
        __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        // Copy second two elements of R_AH0 across entire vector
        __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);

        COMPUTE_STEP(R_AD0, R_EH0, 0);
        __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 1);

        // Replicate upper 128-bits of R_AH0 across both lanes
        R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
        __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 2);
        __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 3);

        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 4);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 5);
        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 6);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 7);

        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 8);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 9);
        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 10);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 11);

        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 12);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 13);
        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 14);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 15);

#undef COMPUTE_STEP
      }

      // Transfer the results to the result matrix
      Index i = 0;
      for (Index j = n; j < n + 16; j++) {
        LinearMapper r0 = res.getLinearMapper(m, j);
        LinearMapper r1 = res.getLinearMapper(m + 8, j);
        typedef typename packet_traits<QInt32>::type Packet;
        r0.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r0.template loadPacket<Packet>(0)));
        r1.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r1.template loadPacket<Packet>(0)));
      }

      // Zero the result block so it can be reused
      memset(blockO, 0, 16 * 16 * sizeof(QInt32));
    }
  }
  aligned_delete(blockO, 16 * 16);
}

#endif

// AVX2 optimized implementation of Mat-Mat product.
// LHS is encoded using signed 8-bit integers.
// RHS is encoded using unsigned 8-bit integers.
#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT

// Define quantized traits
template <bool _ConjLhs, bool _ConjRhs>
class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> {
 public:
  typedef QInt8 LhsScalar;
  typedef QUInt8 RhsScalar;
  typedef QInt32 ResScalar;

  typedef typename packet_traits<LhsScalar>::type LhsPacket;
  typedef LhsPacket LhsPacket4Packing;

  enum {
    // Define register blocking scheme.
    nr = 32,
    mr = 32,
    kr = 8,
    // Ignore progress tracking per loop iteration.
    LhsProgress = -1,
    RhsProgress = -1
  };
};

// Specialized blocking for quantized implementations.
// Used by TensorContractionThreadPool, inputs must have dimensions that are
// multiples of 32.
template <typename ResScalar, typename Index, typename LeftTensor,
          typename left_nocontract_t, typename left_contract_t,
          bool left_inner_dim_contiguous, bool left_inner_dim_reordered,
          int LeftAlignment, typename RightTensor, typename right_nocontract_t,
          typename right_contract_t, bool right_inner_dim_contiguous,
          bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
class TensorContractionBlocking<
    ResScalar,
    TensorContractionInputMapper<
        QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32,
        left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>,
    TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor,
                                 right_nocontract_t, right_contract_t, 32,
                                 right_inner_dim_contiguous,
                                 right_inner_dim_reordered, RightAlignment>,
    Index, ShardingType> {
 public:
  typedef QInt8 LhsScalar;
  typedef QUInt8 RhsScalar;

  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    eigen_assert(m % 32 == 0);
    eigen_assert(k % 32 == 0);
    if (!k || !m || !n) {
      return;
    }

    if (ShardingType == ShardByCol) {
      eigen_assert(n % 32 == 0);
      nc_ = (((n / num_threads) + 31) / 32) * 32;
    } else {
      eigen_assert(n % 32 == 0 || n == 1);
      // Special case to avoid breaking the unimplemented matrix-vector case
      if (n == 1) {
        nc_ = 32;
      }
      mc_ = (((m / num_threads) + 31) / 32) * 32;
    }
  }

  EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }

 private:
  Index kc_;
  Index mc_;
  Index nc_;
};

// Specialized blocking for quantized implementations.
// Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to
// multiples of 32.
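//
// Worked example of the rounding performed below (illustrative only): with
// rows = 100, cols = 50 and depth = 70, the constructor picks
//   m_mc = ((100 + 31) / 32) * 32 = 128,
//   m_nc = ((50 + 31) / 32) * 32 = 64,
//   m_kc = ((70 + 31) / 32) * 32 = 96,
// so the packed buffers always cover whole 32 x 32 tiles; the extra region
// is zero-filled by the "any size" packing routines further down.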
template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<ColMajor, QInt8, QInt8, MaxRows, MaxCols, MaxDepth,
                          KcFactor, false>
    : public level3_blocking<QInt8, QInt8> {
  DenseIndex m_sizeA;
  DenseIndex m_sizeB;

 public:
  gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
                      DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
    this->m_mc = ((rows + 31) / 32) * 32;
    this->m_nc = ((cols + 31) / 32) * 32;
    this->m_kc = ((depth + 31) / 32) * 32;
    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }
  void allocateA() {
    if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA);
  }
  void allocateB() {
    if (this->m_blockB == 0) this->m_blockB = aligned_new<QInt8>(m_sizeB);
  }
  void allocateAll() {
    allocateA();
    allocateB();
  }
  ~gemm_blocking_space() {
    aligned_delete(this->m_blockA, m_sizeA);
    aligned_delete(this->m_blockB, m_sizeB);
  }
};

template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<ColMajor, QInt8, QUInt8, MaxRows, MaxCols, MaxDepth,
                          KcFactor, false>
    : public level3_blocking<QInt8, QUInt8> {
  DenseIndex m_sizeA;
  DenseIndex m_sizeB;

 public:
  gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
                      DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
    this->m_mc = ((rows + 31) / 32) * 32;
    this->m_nc = ((cols + 31) / 32) * 32;
    this->m_kc = ((depth + 31) / 32) * 32;
    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }
  void allocateA() {
    if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA);
  }
  void allocateB() {
    if (this->m_blockB == 0) this->m_blockB = aligned_new<QUInt8>(m_sizeB);
  }
  void allocateAll() {
    allocateA();
    allocateB();
  }
  ~gemm_blocking_space() {
    aligned_delete(this->m_blockA, m_sizeA);
    aligned_delete(this->m_blockB, m_sizeB);
  }
};

// Alternate templates for any input sizes
template <typename Scalar, typename Index, typename DataMapper, int Pack1,
          int Pack2, int StorageOrder, bool Conjugate = false,
          bool PanelMode = false>
struct gemm_pack_lhs_any;
template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
struct gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
                         Conjugate, PanelMode> {
  EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs,
                                    Index depth, Index rows, Index stride = 0,
                                    Index offset = 0);
};

template <typename Scalar, typename Index, typename DataMapper, int nr,
          int StorageOrder, bool Conjugate = false, bool PanelMode = false>
struct gemm_pack_rhs_any;
template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
struct gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
                         PanelMode> {
  EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};

template <typename LhsScalar, typename RhsScalar, typename Index,
          typename DataMapper, int mr, int nr, bool ConjugateLhs = false,
          bool ConjugateRhs = false>
struct gebp_kernel_any;
template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
                       ConjugateRhs> {
  typedef typename DataMapper::LinearMapper LinearMapper;

  EIGEN_DONT_INLINE
  void operator()(const DataMapper& res, const QInt8* blockA,
                  const QUInt8* blockB, Index rows, Index depth, Index cols,
                  QInt32 alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

// Alternate implementations for any input sizes
template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2,
                                         ColMajor, Conjugate, PanelMode>::
operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
           Index stride, Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QInt8>::type Packet;

  // Get vector pointer
  __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);

  // Get even multiples of the dimensions
  Index rows_32 = (rows / 32) * 32;
  Index depth_8 = (depth / 8) * 8;

  // Get padding for when depth is not a multiple of 32
  int padding = 0;
  if (depth % 32 != 0) {
    int depth_32 = (depth / 32) * 32;
    int extra_depth = depth - depth_32;
    int extra_depth_8 = ((extra_depth + 7) / 8) * 8;
    padding = 32 - extra_depth_8;
  }

  // Pack rows in sets of 32
  for (Index m = 0; m < rows_32; m += 32) {
    // Pack depth in sets of 8
    for (Index k = 0; k < depth_8; k += 8) {
      // Load vectors
      __m256i L_A = lhs.template loadPacket<Packet>(m, k);
      __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);

      // Interleave 8-bit elements
      __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
      __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);

      __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
      __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);
      __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
      __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);

      // Interleave 16-bit elements
      __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
      __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);

      // Use permute before we store to cross 128-bit lanes
      __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
      _mm256_store_si256(blockA_256++, L_AD0);

      // Complete packing for 32 x 8 block
      __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
      __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
      _mm256_store_si256(blockA_256++, L_AD8);
      _mm256_store_si256(blockA_256++, L_AD16);
      __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
      _mm256_store_si256(blockA_256++, L_AD24);
      __m256i L_E = lhs.template loadPacket<Packet>(m, k + 4);
      __m256i L_F = lhs.template loadPacket<Packet>(m, k + 5);
      __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
      __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
      __m256i L_G = lhs.template loadPacket<Packet>(m, k + 6);
      __m256i L_H = lhs.template loadPacket<Packet>(m, k + 7);
      __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
      __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
      __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
      _mm256_store_si256(blockA_256++, L_EH0);
      __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
      __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
      _mm256_store_si256(blockA_256++, L_EH8);
      _mm256_store_si256(blockA_256++, L_EH16);
      __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
      _mm256_store_si256(blockA_256++, L_EH24);
    }

    // Finish the k dimension, padding with zeros
    if (depth_8 < depth) {
      __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H;
      switch (depth - depth_8) {
        case 1:
          L_A = lhs.template loadPacket<Packet>(m, depth_8);
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          break;
        case 2:
          L_A = lhs.template loadPacket<Packet>(m, depth_8);
          L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          break;
        case 3:
          L_A = lhs.template loadPacket<Packet>(m, depth_8);
          L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
          L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          break;
        case 4:
          L_A = lhs.template loadPacket<Packet>(m, depth_8);
          L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
          L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
          L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          break;
        case 5:
          L_A = lhs.template loadPacket<Packet>(m, depth_8);
          L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
          L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
          L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
          L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          break;
        case 6:
          L_A = lhs.template loadPacket<Packet>(m, depth_8);
          L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
          L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
          L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
          L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
          L_F = lhs.template loadPacket<Packet>(m, depth_8 + 5);
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          break;
        case 7:
          L_A = lhs.template loadPacket<Packet>(m, depth_8);
          L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
          L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
          L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
          L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
          L_F = lhs.template loadPacket<Packet>(m, depth_8 + 5);
          L_G = lhs.template loadPacket<Packet>(m, depth_8 + 6);
          L_H = _mm256_setzero_si256();
          break;
      }

      // Interleave 8-bit elements
      __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
      __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);

      __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
      __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);

      // Interleave 16-bit elements
      __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
      __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);

      // Use permute before we store to cross 128-bit lanes
      __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
      _mm256_store_si256(blockA_256++, L_AD0);

      // Complete packing
      __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
      __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
      _mm256_store_si256(blockA_256++, L_AD8);
      _mm256_store_si256(blockA_256++, L_AD16);
      __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
      _mm256_store_si256(blockA_256++, L_AD24);
      __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
      __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
      __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
      __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
      __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
      _mm256_store_si256(blockA_256++, L_EH0);
      __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
      __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
      _mm256_store_si256(blockA_256++, L_EH8);
      _mm256_store_si256(blockA_256++, L_EH16);
      __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
      _mm256_store_si256(blockA_256++, L_EH24);
    }
    blockA_256 += padding;
  }

  // Finish the m dimension, padding with zeros
  if (rows_32 < rows) {
    // Pack depth in sets of 8
    for (Index k = 0; k < depth_8; k += 8) {
      // Load vectors
      __m256i L_A = _mm256_setzero_si256();
      __m256i L_B = _mm256_setzero_si256();
      __m256i L_C = _mm256_setzero_si256();
      __m256i L_D = _mm256_setzero_si256();
      __m256i L_E = _mm256_setzero_si256();
      __m256i L_F = _mm256_setzero_si256();
      __m256i L_G = _mm256_setzero_si256();
      __m256i L_H = _mm256_setzero_si256();
      for (Index m = 0; m < rows - rows_32; m++) {
        QInt8* ptr = (QInt8*)&L_A;
        ptr[m] = lhs(rows_32 + m, k);
        ptr = (QInt8*)&L_B;
        ptr[m] = lhs(rows_32 + m, k + 1);
        ptr = (QInt8*)&L_C;
        ptr[m] = lhs(rows_32 + m, k + 2);
        ptr = (QInt8*)&L_D;
        ptr[m] = lhs(rows_32 + m, k + 3);
        ptr = (QInt8*)&L_E;
        ptr[m] = lhs(rows_32 + m, k + 4);
        ptr = (QInt8*)&L_F;
        ptr[m] = lhs(rows_32 + m, k + 5);
        ptr = (QInt8*)&L_G;
        ptr[m] = lhs(rows_32 + m, k + 6);
        ptr = (QInt8*)&L_H;
        ptr[m] = lhs(rows_32 + m, k + 7);
      }

      // Interleave 8-bit elements
      __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
      __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
      __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
      __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);

      // Interleave 16-bit elements
      __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
      __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);

      // Use permute before we store to cross 128-bit lanes
      __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
      _mm256_store_si256(blockA_256++, L_AD0);

      // Complete packing for 32 x 8 block
      __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
      __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
      _mm256_store_si256(blockA_256++, L_AD8);
      _mm256_store_si256(blockA_256++, L_AD16);
      __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
      _mm256_store_si256(blockA_256++, L_AD24);
      __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
      __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
      __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
      __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
      __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
      _mm256_store_si256(blockA_256++, L_EH0);
      __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
      __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
      _mm256_store_si256(blockA_256++, L_EH8);
      _mm256_store_si256(blockA_256++, L_EH16);
      __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
      _mm256_store_si256(blockA_256++, L_EH24);
    }

    // Finish the k dimension, padding with zeros
    if (depth_8 < depth) {
      __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H;
      QInt8* ptr;
      switch (depth - depth_8) {
        case 1:
          L_A = _mm256_setzero_si256();
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          for (Index m = 0; m < rows - rows_32; m++) {
            QInt8* ptr = (QInt8*)&L_A;
            ptr[m] = lhs(rows_32 + m, depth_8);
          }
          break;
        case 2:
          L_A = _mm256_setzero_si256();
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          for (Index m = 0; m < rows - rows_32; m++) {
            ptr = (QInt8*)&L_A;
            ptr[m] = lhs(rows_32 + m, depth_8);
            ptr = (QInt8*)&L_B;
            ptr[m] = lhs(rows_32 + m, depth_8 + 1);
          }
          break;
        case 3:
          L_A = _mm256_setzero_si256();
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          for (Index m = 0; m < rows - rows_32; m++) {
            ptr = (QInt8*)&L_A;
            ptr[m] = lhs(rows_32 + m, depth_8);
            ptr = (QInt8*)&L_B;
            ptr[m] = lhs(rows_32 + m, depth_8 + 1);
            ptr = (QInt8*)&L_C;
            ptr[m] = lhs(rows_32 + m, depth_8 + 2);
          }
          break;
        case 4:
          L_A = _mm256_setzero_si256();
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          for (Index m = 0; m < rows - rows_32; m++) {
            ptr = (QInt8*)&L_A;
            ptr[m] = lhs(rows_32 + m, depth_8);
            ptr = (QInt8*)&L_B;
            ptr[m] = lhs(rows_32 + m, depth_8 + 1);
            ptr = (QInt8*)&L_C;
            ptr[m] = lhs(rows_32 + m, depth_8 + 2);
            ptr = (QInt8*)&L_D;
            ptr[m] = lhs(rows_32 + m, depth_8 + 3);
          }
          break;
        case 5:
          L_A = _mm256_setzero_si256();
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          for (Index m = 0; m < rows - rows_32; m++) {
            ptr = (QInt8*)&L_A;
            ptr[m] = lhs(rows_32 + m, depth_8);
            ptr = (QInt8*)&L_B;
            ptr[m] = lhs(rows_32 + m, depth_8 + 1);
            ptr = (QInt8*)&L_C;
            ptr[m] = lhs(rows_32 + m, depth_8 + 2);
            ptr = (QInt8*)&L_D;
            ptr[m] = lhs(rows_32 + m, depth_8 + 3);
            ptr = (QInt8*)&L_E;
            ptr[m] = lhs(rows_32 + m, depth_8 + 4);
          }
          break;
        case 6:
          L_A = _mm256_setzero_si256();
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          for (Index m = 0; m < rows - rows_32; m++) {
            ptr = (QInt8*)&L_A;
            ptr[m] = lhs(rows_32 + m, depth_8);
            ptr = (QInt8*)&L_B;
            ptr[m] = lhs(rows_32 + m, depth_8 + 1);
            ptr = (QInt8*)&L_C;
            ptr[m] = lhs(rows_32 + m, depth_8 + 2);
            ptr = (QInt8*)&L_D;
            ptr[m] = lhs(rows_32 + m, depth_8 + 3);
            ptr = (QInt8*)&L_E;
            ptr[m] = lhs(rows_32 + m, depth_8 + 4);
            ptr = (QInt8*)&L_F;
            ptr[m] = lhs(rows_32 + m, depth_8 + 5);
          }
          break;
        case 7:
          L_A = _mm256_setzero_si256();
          L_B = _mm256_setzero_si256();
          L_C = _mm256_setzero_si256();
          L_D = _mm256_setzero_si256();
          L_E = _mm256_setzero_si256();
          L_F = _mm256_setzero_si256();
          L_G = _mm256_setzero_si256();
          L_H = _mm256_setzero_si256();
          for (Index m = 0; m < rows - rows_32; m++) {
            ptr = (QInt8*)&L_A;
            ptr[m] = lhs(rows_32 + m, depth_8);
            ptr = (QInt8*)&L_B;
            ptr[m] = lhs(rows_32 + m, depth_8 + 1);
            ptr = (QInt8*)&L_C;
            ptr[m] = lhs(rows_32 + m, depth_8 + 2);
            ptr = (QInt8*)&L_D;
            ptr[m] = lhs(rows_32 + m, depth_8 + 3);
            ptr = (QInt8*)&L_E;
            ptr[m] = lhs(rows_32 + m, depth_8 + 4);
            ptr = (QInt8*)&L_F;
            ptr[m] = lhs(rows_32 + m, depth_8 + 5);
            ptr = (QInt8*)&L_G;
            ptr[m] = lhs(rows_32 + m, depth_8 + 6);
          }
          break;
      }

      // Interleave 8-bit elements
      __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
      __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
      __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
      __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);

      // Interleave 16-bit elements
      __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
      __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);

      // Use permute before we store to cross 128-bit lanes
      __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
      _mm256_store_si256(blockA_256++, L_AD0);

      // Complete packing
      __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
      __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
      __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
      _mm256_store_si256(blockA_256++, L_AD8);
      _mm256_store_si256(blockA_256++, L_AD16);
      __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
      _mm256_store_si256(blockA_256++, L_AD24);
      __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
      __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
      __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
      __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
      __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
      __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
      _mm256_store_si256(blockA_256++, L_EH0);
      __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
      __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
      __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
      _mm256_store_si256(blockA_256++, L_EH8);
      _mm256_store_si256(blockA_256++, L_EH16);
      __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
      _mm256_store_si256(blockA_256++, L_EH24);
    }
  }
}

template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr,
                                         ColMajor, Conjugate, PanelMode>::
operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
           Index stride, Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QUInt8>::type Packet;

  // Get vector pointer
  __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);

  // Get even multiples of the dimensions
  Index cols_32 = (cols / 32) * 32;
  Index depth_32 = (depth / 32) * 32;

  // Perform a step of the packing for 4 columns
  __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24;
#define PACK_STEP                                             \
  R_AB_L = _mm256_unpacklo_epi64(R_A, R_B);                   \
  R_CD_L = _mm256_unpacklo_epi64(R_C, R_D);                   \
  R_AB_H = _mm256_unpackhi_epi64(R_A, R_B);                   \
  R_CD_H = _mm256_unpackhi_epi64(R_C, R_D);                   \
  R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20);   \
  R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31);  \
  R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20);   \
  R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31);  \
  _mm256_store_si256(blockB_256, R_AD_0);                     \
  _mm256_store_si256(blockB_256 + 8, R_AD_8);                 \
  _mm256_store_si256(blockB_256 + 16, R_AD_16);               \
  _mm256_store_si256(blockB_256 + 24, R_AD_24);               \
  blockB_256++;

  // Pack cols in sets of 32
  for (Index n = 0; n < cols_32; n += 32) {
    // Pack depth in sets of 32
    for (Index k = 0; k < depth_32; k += 32) {
      __m256i R_A = rhs.template loadPacket<Packet>(k, n);
      __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
      __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
      __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 4);
      R_B = rhs.template loadPacket<Packet>(k, n + 5);
      R_C = rhs.template loadPacket<Packet>(k, n + 6);
      R_D = rhs.template loadPacket<Packet>(k, n + 7);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 8);
      R_B = rhs.template loadPacket<Packet>(k, n + 9);
      R_C = rhs.template loadPacket<Packet>(k, n + 10);
      R_D = rhs.template loadPacket<Packet>(k, n + 11);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 12);
      R_B = rhs.template loadPacket<Packet>(k, n + 13);
      R_C = rhs.template loadPacket<Packet>(k, n + 14);
      R_D = rhs.template loadPacket<Packet>(k, n + 15);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 16);
      R_B = rhs.template loadPacket<Packet>(k, n + 17);
      R_C = rhs.template loadPacket<Packet>(k, n + 18);
      R_D = rhs.template loadPacket<Packet>(k, n + 19);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 20);
      R_B = rhs.template loadPacket<Packet>(k, n + 21);
      R_C = rhs.template loadPacket<Packet>(k, n + 22);
      R_D = rhs.template loadPacket<Packet>(k, n + 23);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 24);
      R_B = rhs.template loadPacket<Packet>(k, n + 25);
      R_C = rhs.template loadPacket<Packet>(k, n + 26);
      R_D = rhs.template loadPacket<Packet>(k, n + 27);
      PACK_STEP;

      R_A = rhs.template loadPacket<Packet>(k, n + 28);
      R_B = rhs.template loadPacket<Packet>(k, n + 29);
      R_C = rhs.template loadPacket<Packet>(k, n + 30);
      R_D = rhs.template loadPacket<Packet>(k, n + 31);
      PACK_STEP;

      blockB_256 += 24;
    }

    if (depth_32 < depth) {
      QUInt8* ptr;
      __m256i R_A = _mm256_setzero_si256();
      __m256i R_B = _mm256_setzero_si256();
      __m256i R_C = _mm256_setzero_si256();
      __m256i R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 1);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 2);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 3);
      }
      PACK_STEP;

      R_A = _mm256_setzero_si256();
      R_B = _mm256_setzero_si256();
      R_C = _mm256_setzero_si256();
      R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n + 4);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 5);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 6);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 7);
      }
      PACK_STEP;

      R_A = _mm256_setzero_si256();
      R_B = _mm256_setzero_si256();
      R_C = _mm256_setzero_si256();
      R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n + 8);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 9);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 10);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 11);
      }
      PACK_STEP;

      R_A = _mm256_setzero_si256();
      R_B = _mm256_setzero_si256();
      R_C = _mm256_setzero_si256();
      R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n + 12);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 13);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 14);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 15);
      }
      PACK_STEP;

      R_A = _mm256_setzero_si256();
      R_B = _mm256_setzero_si256();
      R_C = _mm256_setzero_si256();
      R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n + 16);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 17);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 18);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 19);
      }
      PACK_STEP;

      R_A = _mm256_setzero_si256();
      R_B = _mm256_setzero_si256();
      R_C = _mm256_setzero_si256();
      R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n + 20);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 21);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 22);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 23);
      }
      PACK_STEP;

      R_A = _mm256_setzero_si256();
      R_B = _mm256_setzero_si256();
      R_C = _mm256_setzero_si256();
      R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n + 24);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 25);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 26);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 27);
      }
      PACK_STEP;

      R_A = _mm256_setzero_si256();
      R_B = _mm256_setzero_si256();
      R_C = _mm256_setzero_si256();
      R_D = _mm256_setzero_si256();
      for (Index k = depth_32; k < depth; k++) {
        ptr = (QUInt8*)&R_A;
        ptr[k - depth_32] = rhs(k, n + 28);
        ptr = (QUInt8*)&R_B;
        ptr[k - depth_32] = rhs(k, n + 29);
        ptr = (QUInt8*)&R_C;
        ptr[k - depth_32] = rhs(k, n + 30);
        ptr = (QUInt8*)&R_D;
        ptr[k - depth_32] = rhs(k, n + 31);
      }
      PACK_STEP;
      blockB_256 += 24;
    }
  }

  // Finish packing cols
  if (cols_32 < cols) {
    // Pack depth in sets of 32
    for (Index k = 0; k < depth_32; k += 32) {
      __m256i R_A, R_B, R_C, R_D;
      Index n;
      for (n = cols_32; n < cols; n += 4) {
        switch (cols - n) {
          case 1:
            R_A = rhs.template loadPacket<Packet>(k, n);
            R_B = _mm256_setzero_si256();
            R_C = _mm256_setzero_si256();
            R_D = _mm256_setzero_si256();
            PACK_STEP;
            break;
          case 2:
            R_A = rhs.template loadPacket<Packet>(k, n);
            R_B = rhs.template loadPacket<Packet>(k, n + 1);
            R_C = _mm256_setzero_si256();
            R_D = _mm256_setzero_si256();
            PACK_STEP;
            break;
          case 3:
            R_A = rhs.template loadPacket<Packet>(k, n);
            R_B = rhs.template loadPacket<Packet>(k, n + 1);
            R_C = rhs.template loadPacket<Packet>(k, n + 2);
            R_D = _mm256_setzero_si256();
            PACK_STEP;
            break;
          default:
            R_A = rhs.template loadPacket<Packet>(k, n);
            R_B = rhs.template loadPacket<Packet>(k, n + 1);
            R_C = rhs.template loadPacket<Packet>(k, n + 2);
            R_D = rhs.template loadPacket<Packet>(k, n + 3);
            PACK_STEP;
            break;
        }
      }

      // Increment the block pointer.
      // We must pad if cols is not a multiple of 32.
      blockB_256 += 32 - (n - cols_32) / 4;
    }

    if (depth_32 < depth) {
      for (Index n = cols_32; n < cols; n += 4) {
        QUInt8* ptr;
        __m256i R_A = _mm256_setzero_si256();
        __m256i R_B = _mm256_setzero_si256();
        __m256i R_C = _mm256_setzero_si256();
        __m256i R_D = _mm256_setzero_si256();
        switch (cols - n) {
          case 1:
            for (Index k = depth_32; k < depth; k++) {
              ptr = (QUInt8*)&R_A;
              ptr[k - depth_32] = rhs(k, n);
            }
            PACK_STEP;
            break;
          case 2:
            for (Index k = depth_32; k < depth; k++) {
              ptr = (QUInt8*)&R_A;
              ptr[k - depth_32] = rhs(k, n);
              ptr = (QUInt8*)&R_B;
              ptr[k - depth_32] = rhs(k, n + 1);
            }
            PACK_STEP;
            break;
          case 3:
            for (Index k = depth_32; k < depth; k++) {
              ptr = (QUInt8*)&R_A;
              ptr[k - depth_32] = rhs(k, n);
              ptr = (QUInt8*)&R_B;
              ptr[k - depth_32] = rhs(k, n + 1);
              ptr = (QUInt8*)&R_C;
              ptr[k - depth_32] = rhs(k, n + 2);
            }
            PACK_STEP;
            break;
          default:
            for (Index k = depth_32; k < depth; k++) {
              ptr = (QUInt8*)&R_A;
              ptr[k - depth_32] = rhs(k, n);
              ptr = (QUInt8*)&R_B;
              ptr[k - depth_32] = rhs(k, n + 1);
              ptr = (QUInt8*)&R_C;
              ptr[k - depth_32] = rhs(k, n + 2);
              ptr = (QUInt8*)&R_D;
              ptr[k - depth_32] = rhs(k, n + 3);
            }
            PACK_STEP;
            break;
        }
      }
    }
  }
#undef PACK_STEP
}

template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE void gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr,
                                       ConjugateLhs, ConjugateRhs>::
operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
           Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
           Index strideB, Index offsetA, Index offsetB) {
  EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  eigen_assert(alpha.value == 1);
  eigen_assert(strideA == -1);
  eigen_assert(strideB == -1);
  eigen_assert(offsetA == 0);
  eigen_assert(offsetB == 0);
  eigen_assert(rows > 0);
  eigen_assert(cols > 0);
  eigen_assert(depth > 0);
  eigen_assert(blockA);
  eigen_assert(blockB);

  Index rows_32 = ((rows + 31) / 32) * 32;
  Index cols_32 = ((cols + 31) / 32) * 32;
  Index depth_32 = ((depth + 31) / 32) * 32;

  // Create result block
  ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
  memset(blockO, 0, 32 * 32 * sizeof(QInt32));

  // Get vectorized pointers
  __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
  const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
  const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);
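
  // The inner kernel below builds each partial product from two AVX2
  // instructions: _mm256_maddubs_epi16(rhs_u8, lhs_s8) multiplies unsigned
  // RHS bytes with signed LHS bytes and adds adjacent pairs into 16-bit sums,
  // and _mm256_madd_epi16(sum16, ONE) (ONE = 0x00010001 per 32-bit lane) adds
  // adjacent 16-bit sums into 32-bit accumulators. Per 32-bit lane this is
  // the dot product of 4 adjacent depth values; an illustrative scalar sketch
  // (names lhs_val/rhs_val are hypothetical, not part of the kernel):
  //
  //   int32_t acc = 0;
  //   for (int d = 0; d < 4; ++d)
  //     acc += int32_t(uint8_t(rhs_val[d])) * int32_t(int8_t(lhs_val[d]));
  //
  // which is why both packing routines above keep groups of adjacent depth
  // values contiguous.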
blockA_256[indexL++]; 1532 __m256i L_EH0 = blockA_256[indexL++]; 1533 __m256i L_EH8 = blockA_256[indexL++]; 1534 __m256i L_EH16 = blockA_256[indexL++]; 1535 __m256i L_EH24 = blockA_256[indexL++]; 1536 __m256i R_AH0 = blockB_256[indexR++]; 1537 __m256i R_AH4 = blockB_256[indexR++]; 1538 __m256i R_AH8 = blockB_256[indexR++]; 1539 __m256i R_AH12 = blockB_256[indexR++]; 1540 __m256i R_AH16 = blockB_256[indexR++]; 1541 __m256i R_AH20 = blockB_256[indexR++]; 1542 __m256i R_AH24 = blockB_256[indexR++]; 1543 __m256i R_AH28 = blockB_256[indexR++]; 1544 1545 // This constant is used with madd to convert 16 bit to 32 bit 1546 const __m256i ONE = _mm256_set1_epi32(0x00010001); 1547 1548 // Declare variables used in COMPUTE_STEP 1549 __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32; 1550 1551 #define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET) \ 1552 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0); \ 1553 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 1554 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0); \ 1555 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 1556 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 1557 _mm256_store_si256( \ 1558 blockO_256 + 4 * OFFSET, \ 1559 _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32)); \ 1560 \ 1561 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8); \ 1562 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 1563 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8); \ 1564 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 1565 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 1566 _mm256_store_si256( \ 1567 blockO_256 + 4 * OFFSET + 1, \ 1568 _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \ 1569 \ 1570 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16); \ 1571 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 1572 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16); \ 1573 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 1574 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 1575 _mm256_store_si256( \ 1576 blockO_256 + 4 * OFFSET + 2, \ 1577 _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \ 1578 \ 1579 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24); \ 1580 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 1581 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24); \ 1582 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 1583 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 1584 _mm256_store_si256( \ 1585 blockO_256 + 4 * OFFSET + 3, \ 1586 _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32)); 1587 1588 // Permute and shuffle to copy a single value across the entire vector 1589 // Then compute the multiplication 1590 __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00); 1591 __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1592 __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1593 COMPUTE_STEP(R_AD0, R_EH0, 0); 1594 __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1595 __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1596 COMPUTE_STEP(R_AD1, R_EH1, 1); 1597 R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11); 1598 __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1599 __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1600 COMPUTE_STEP(R_AD2, R_EH2, 2); 1601 __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1602 __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1603 COMPUTE_STEP(R_AD3, R_EH3, 3); 1604 1605 R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00); 1606 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1607 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1608 COMPUTE_STEP(R_AD0, R_EH0, 4); 1609 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 
1610 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1611 COMPUTE_STEP(R_AD1, R_EH1, 5); 1612 R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11); 1613 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1614 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1615 COMPUTE_STEP(R_AD2, R_EH2, 6); 1616 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1617 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1618 COMPUTE_STEP(R_AD3, R_EH3, 7); 1619 1620 R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00); 1621 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1622 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1623 COMPUTE_STEP(R_AD0, R_EH0, 8); 1624 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1625 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1626 COMPUTE_STEP(R_AD1, R_EH1, 9); 1627 R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11); 1628 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1629 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1630 COMPUTE_STEP(R_AD2, R_EH2, 10); 1631 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1632 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1633 COMPUTE_STEP(R_AD3, R_EH3, 11); 1634 1635 R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00); 1636 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1637 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1638 COMPUTE_STEP(R_AD0, R_EH0, 12); 1639 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1640 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1641 COMPUTE_STEP(R_AD1, R_EH1, 13); 1642 R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11); 1643 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1644 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1645 COMPUTE_STEP(R_AD2, R_EH2, 14); 1646 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1647 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1648 COMPUTE_STEP(R_AD3, R_EH3, 15); 1649 1650 R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00); 1651 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1652 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1653 COMPUTE_STEP(R_AD0, R_EH0, 16); 1654 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1655 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1656 COMPUTE_STEP(R_AD1, R_EH1, 17); 1657 R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11); 1658 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1659 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1660 COMPUTE_STEP(R_AD2, R_EH2, 18); 1661 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1662 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1663 COMPUTE_STEP(R_AD3, R_EH3, 19); 1664 1665 R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00); 1666 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1667 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1668 COMPUTE_STEP(R_AD0, R_EH0, 20); 1669 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1670 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1671 COMPUTE_STEP(R_AD1, R_EH1, 21); 1672 R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11); 1673 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1674 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1675 COMPUTE_STEP(R_AD2, R_EH2, 22); 1676 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1677 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1678 COMPUTE_STEP(R_AD3, R_EH3, 23); 1679 1680 R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00); 1681 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1682 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1683 COMPUTE_STEP(R_AD0, R_EH0, 24); 1684 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1685 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1686 COMPUTE_STEP(R_AD1, R_EH1, 25); 1687 R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11); 1688 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1689 R_EH2 = 
_mm256_shuffle_epi32(R_AH0_, 0x55); 1690 COMPUTE_STEP(R_AD2, R_EH2, 26); 1691 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1692 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1693 COMPUTE_STEP(R_AD3, R_EH3, 27); 1694 1695 R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00); 1696 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1697 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1698 COMPUTE_STEP(R_AD0, R_EH0, 28); 1699 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1700 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1701 COMPUTE_STEP(R_AD1, R_EH1, 29); 1702 R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11); 1703 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 1704 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 1705 COMPUTE_STEP(R_AD2, R_EH2, 30); 1706 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 1707 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 1708 COMPUTE_STEP(R_AD3, R_EH3, 31); 1709 1710 #undef COMPUTE_STEP 1711 } 1712 1713 // Transfer the results to the result matrix. 1714 if (m + 32 <= rows && n + 32 <= cols) { 1715 Index i = 0; 1716 for (Index j = n; j < n + 32; j++) { 1717 LinearMapper r0 = res.getLinearMapper(m, j); 1718 LinearMapper r1 = res.getLinearMapper(m + 8, j); 1719 LinearMapper r2 = res.getLinearMapper(m + 16, j); 1720 LinearMapper r3 = res.getLinearMapper(m + 24, j); 1721 typedef typename packet_traits<QInt32>::type Packet; 1722 r0.template storePacket<Packet>( 1723 0, _mm256_add_epi32(blockO_256[i++], 1724 r0.template loadPacket<Packet>(0))); 1725 r1.template storePacket<Packet>( 1726 0, _mm256_add_epi32(blockO_256[i++], 1727 r1.template loadPacket<Packet>(0))); 1728 r2.template storePacket<Packet>( 1729 0, _mm256_add_epi32(blockO_256[i++], 1730 r2.template loadPacket<Packet>(0))); 1731 r3.template storePacket<Packet>( 1732 0, _mm256_add_epi32(blockO_256[i++], 1733 r3.template loadPacket<Packet>(0))); 1734 } 1735 } else { 1736 for (Index j = n; j < cols; j++) { 1737 for (Index i = m; i < rows; i++) { 1738 res(i, j) = blockO[(j - n) * 32 + (i - m)]; 1739 } 1740 } 1741 } 1742 1743 // Zero the result block so it can be reused 1744 memset(blockO, 0, 32 * 32 * sizeof(QInt32)); 1745 } 1746 } 1747 } 1748 1749 // Below are the fully optimized versions that are correct only for sizes that 1750 // are multiple of 32. It is about a 10% performance benefit to keep these 1751 // implementations separate. 1752 1753 // Arrange a block of the left input matrix in contiguous memory. 1754 // 1755 // Given column major input (A0 beside A1 in memory): 1756 // A0 B0 C0 D0 E0 F0 G0 H0 ... 1757 // A1 B1 C1 D1 E1 F1 G1 H1 ... 1758 // A2 B2 C2 D2 E2 F2 G2 H2 ... 1759 // A3 B3 C3 D3 E3 F3 G3 H3 ... 1760 // A4 B4 C4 D4 E4 F4 G4 H4 ... 1761 // A5 B5 C5 D5 E5 F5 G5 H5 ... 1762 // A6 B6 C6 D6 E6 F6 G6 H6 ... 1763 // A7 B7 C7 D7 E7 F7 G7 H7 ... 1764 // A8 ... 1765 // ... 1766 // 1767 // Packing yields output (A0 beside B0 in memory): 1768 // A0 B0 C0 D0 1769 // A1 B1 C1 D1 1770 // A2 B2 C2 D2 1771 // A3 B3 C3 D3 1772 // A4 B4 C4 D4 1773 // A5 B5 C5 D5 1774 // A6 B6 C6 D6 1775 // A7 B7 C7 D7 1776 // ... 1777 // A31 B31 C31 D31 1778 // E0 F0 G0 H0 1779 // E1 F1 G1 H1 1780 // E2 F2 G2 H2 1781 // E3 F3 G3 H3 1782 // E4 F4 G4 H4 1783 // E5 F5 G5 H5 1784 // E6 F6 G6 H6 1785 // E7 F7 G7 H7 1786 // ... 1787 // 1788 // Four elements of the same row are arranged contiguously because maddubs and 1789 // madd both perform an adjacent addition in the kernel. 
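// Illustrative sketch (not part of the build): a scalar reference that is
// believed to match the packed layout produced by the fast-path LHS packer
// below, assuming rows and depth are multiples of 32 as that path requires;
// the loop variables are only for illustration. For each 32-row x 8-depth
// tile, four consecutive depth values of a row are stored contiguously, rows
// are grouped in sets of 8, and depth offsets 4..7 follow in a second half
// of the tile:
//
//   for (Index m = 0; m < rows; m += 32)           // 32-row panels
//     for (Index k = 0; k < depth; k += 8)         // 8-deep tiles
//       for (int half = 0; half < 2; ++half)       // k+0..3, then k+4..7
//         for (int r8 = 0; r8 < 4; ++r8)           // row groups of 8
//           for (int r = 0; r < 8; ++r)
//             for (int d = 0; d < 4; ++d)
//               *blockA++ = lhs(m + 8 * r8 + r, k + 4 * half + d);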
1790 template <typename Index, typename DataMapper, int Pack1, int Pack2, 1791 bool Conjugate, bool PanelMode> 1792 struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, QInt8, ColMajor, 1793 Conjugate, PanelMode> { 1794 EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs, 1795 Index depth, Index rows, Index stride = 0, 1796 Index offset = 0); 1797 }; 1798 1799 template <typename Index, typename DataMapper, int Pack1, int Pack2, 1800 bool Conjugate, bool PanelMode> 1801 EIGEN_DONT_INLINE void gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, 1802 QInt8, ColMajor, Conjugate, PanelMode>:: 1803 operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows, 1804 Index stride, Index offset) { 1805 eigen_assert(stride == 0); 1806 eigen_assert(offset == 0); 1807 1808 typedef typename packet_traits<QInt8>::type Packet; 1809 1810 // Use alternate function for weird sizes 1811 if (rows % 32 != 0 || depth % 32 != 0) { 1812 gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, 1813 Conjugate, PanelMode> lhs_pack; 1814 return lhs_pack(blockA, lhs, depth, rows, stride, offset); 1815 } 1816 1817 // Get vector pointer 1818 __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA); 1819 1820 // Pack rows in sets of 32 1821 for (Index m = 0; m < rows; m += 32) { 1822 // Pack depth in sets of 8 1823 for (Index k = 0; k < depth; k += 8) { 1824 // Load vectors 1825 __m256i L_A = lhs.template loadPacket<Packet>(m, k); 1826 __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1); 1827 1828 // Interleave 8-bit elements 1829 __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B); 1830 __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B); 1831 1832 __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2); 1833 __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3); 1834 __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D); 1835 __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D); 1836 1837 // Interleave 16-bit elements 1838 __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16); 1839 __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16); 1840 1841 // Use permute before we store to cross 128-bit lanes 1842 __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20); 1843 _mm256_store_si256(blockA_256++, L_AD0); 1844 1845 // Complete packing for 32 x 8 block 1846 __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31); 1847 __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24); 1848 __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24); 1849 __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20); 1850 _mm256_store_si256(blockA_256++, L_AD8); 1851 _mm256_store_si256(blockA_256++, L_AD16); 1852 __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31); 1853 _mm256_store_si256(blockA_256++, L_AD24); 1854 __m256i L_E = lhs.template loadPacket<Packet>(m, k + 4); 1855 __m256i L_F = lhs.template loadPacket<Packet>(m, k + 5); 1856 __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F); 1857 __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F); 1858 __m256i L_G = lhs.template loadPacket<Packet>(m, k + 6); 1859 __m256i L_H = lhs.template loadPacket<Packet>(m, k + 7); 1860 __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H); 1861 __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H); 1862 __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16); 1863 __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16); 1864 __m256i L_EH0 = 
_mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20); 1865 _mm256_store_si256(blockA_256++, L_EH0); 1866 __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31); 1867 __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24); 1868 __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24); 1869 __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20); 1870 _mm256_store_si256(blockA_256++, L_EH8); 1871 _mm256_store_si256(blockA_256++, L_EH16); 1872 __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31); 1873 _mm256_store_si256(blockA_256++, L_EH24); 1874 } 1875 } 1876 } 1877 1878 // Arrange a block of the right input matrix in contiguous memory. 1879 // 1880 // Given column major input (A0 beside A1 in memory): 1881 // A0 B0 C0 D0 E0 F0 G0 H0 ... 1882 // A1 B1 C1 D1 E1 F1 G1 H1 ... 1883 // A2 B2 C2 D2 E2 F2 G2 H2 ... 1884 // A3 B3 C3 D3 E3 F3 G3 H3 ... 1885 // A4 B4 C4 D4 E4 F4 G4 H4 ... 1886 // A5 B5 C5 D5 E5 F5 G5 H5 ... 1887 // A6 B6 C6 D6 E6 F6 G6 H6 ... 1888 // A7 B7 C7 D7 E7 F7 G7 H7 ... 1889 // A8 ... 1890 // ... 1891 // 1892 // Packing yields row major output (A0 beside A1 in memory): 1893 // A0 A1 A2 A3 A4 A5 A6 A7 1894 // B0 B1 B2 B3 B4 B5 B6 B7 1895 // ... 1896 // 1897 // At least four elements of the same col are arranged contiguously because 1898 // maddubs and madd both perform an adjacent addition in the kernel. We can 1899 // save work by leaving 8 adjacent elements because kr = 8. 1900 template <typename Index, typename DataMapper, int nr, bool Conjugate, 1901 bool PanelMode> 1902 struct gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, 1903 PanelMode> { 1904 EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs, 1905 Index depth, Index cols, Index stride = 0, 1906 Index offset = 0); 1907 }; 1908 1909 template <typename Index, typename DataMapper, int nr, bool Conjugate, 1910 bool PanelMode> 1911 EIGEN_DONT_INLINE void gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor, 1912 Conjugate, PanelMode>:: 1913 operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols, 1914 Index stride, Index offset) { 1915 eigen_assert(stride == 0); 1916 eigen_assert(offset == 0); 1917 1918 typedef typename packet_traits<QUInt8>::type Packet; 1919 1920 // Use alternate function for weird sizes 1921 if (cols % 32 != 0 || depth % 32 != 0) { 1922 gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate, 1923 PanelMode> rhs_pack; 1924 return rhs_pack(blockB, rhs, depth, cols, stride, offset); 1925 } 1926 1927 // Get vector pointer 1928 __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB); 1929 1930 // Perform a step of the packing for 4 columns 1931 __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24; 1932 #define PACK_STEP \ 1933 R_AB_L = _mm256_unpacklo_epi64(R_A, R_B); \ 1934 R_CD_L = _mm256_unpacklo_epi64(R_C, R_D); \ 1935 R_AB_H = _mm256_unpackhi_epi64(R_A, R_B); \ 1936 R_CD_H = _mm256_unpackhi_epi64(R_C, R_D); \ 1937 R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20); \ 1938 R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \ 1939 R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20); \ 1940 R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \ 1941 _mm256_store_si256(blockB_256, R_AD_0); \ 1942 _mm256_store_si256(blockB_256 + 8, R_AD_8); \ 1943 _mm256_store_si256(blockB_256 + 16, R_AD_16); \ 1944 _mm256_store_si256(blockB_256 + 24, R_AD_24); \ 1945 blockB_256++; 1946 1947 // Pack cols in sets of 32 1948 for 
(Index n = 0; n < cols; n += 32) { 1949 // Pack depth in sets of 32 1950 for (Index k = 0; k < depth; k += 32) { 1951 __m256i R_A = rhs.template loadPacket<Packet>(k, n); 1952 __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1); 1953 __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2); 1954 __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3); 1955 PACK_STEP; 1956 1957 R_A = rhs.template loadPacket<Packet>(k, n + 4); 1958 R_B = rhs.template loadPacket<Packet>(k, n + 5); 1959 R_C = rhs.template loadPacket<Packet>(k, n + 6); 1960 R_D = rhs.template loadPacket<Packet>(k, n + 7); 1961 PACK_STEP; 1962 1963 R_A = rhs.template loadPacket<Packet>(k, n + 8); 1964 R_B = rhs.template loadPacket<Packet>(k, n + 9); 1965 R_C = rhs.template loadPacket<Packet>(k, n + 10); 1966 R_D = rhs.template loadPacket<Packet>(k, n + 11); 1967 PACK_STEP; 1968 1969 R_A = rhs.template loadPacket<Packet>(k, n + 12); 1970 R_B = rhs.template loadPacket<Packet>(k, n + 13); 1971 R_C = rhs.template loadPacket<Packet>(k, n + 14); 1972 R_D = rhs.template loadPacket<Packet>(k, n + 15); 1973 PACK_STEP; 1974 1975 R_A = rhs.template loadPacket<Packet>(k, n + 16); 1976 R_B = rhs.template loadPacket<Packet>(k, n + 17); 1977 R_C = rhs.template loadPacket<Packet>(k, n + 18); 1978 R_D = rhs.template loadPacket<Packet>(k, n + 19); 1979 PACK_STEP; 1980 1981 R_A = rhs.template loadPacket<Packet>(k, n + 20); 1982 R_B = rhs.template loadPacket<Packet>(k, n + 21); 1983 R_C = rhs.template loadPacket<Packet>(k, n + 22); 1984 R_D = rhs.template loadPacket<Packet>(k, n + 23); 1985 PACK_STEP; 1986 1987 R_A = rhs.template loadPacket<Packet>(k, n + 24); 1988 R_B = rhs.template loadPacket<Packet>(k, n + 25); 1989 R_C = rhs.template loadPacket<Packet>(k, n + 26); 1990 R_D = rhs.template loadPacket<Packet>(k, n + 27); 1991 PACK_STEP; 1992 1993 R_A = rhs.template loadPacket<Packet>(k, n + 28); 1994 R_B = rhs.template loadPacket<Packet>(k, n + 29); 1995 R_C = rhs.template loadPacket<Packet>(k, n + 30); 1996 R_D = rhs.template loadPacket<Packet>(k, n + 31); 1997 PACK_STEP; 1998 1999 blockB_256 += 24; 2000 } 2001 } 2002 #undef PACK_STEP 2003 } 2004 2005 // Perform the actual multiplication on packed inputs 2006 template <typename Index, typename DataMapper, int mr, int nr, 2007 bool ConjugateLhs, bool ConjugateRhs> 2008 struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, 2009 ConjugateRhs> { 2010 typedef typename DataMapper::LinearMapper LinearMapper; 2011 2012 EIGEN_DONT_INLINE 2013 void operator()(const DataMapper& res, const QInt8* blockA, 2014 const QUInt8* blockB, Index rows, Index depth, Index cols, 2015 QInt32 alpha, Index strideA = -1, Index strideB = -1, 2016 Index offsetA = 0, Index offsetB = 0); 2017 }; 2018 2019 template <typename Index, typename DataMapper, int mr, int nr, 2020 bool ConjugateLhs, bool ConjugateRhs> 2021 EIGEN_DONT_INLINE void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, 2022 ConjugateLhs, ConjugateRhs>:: 2023 operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB, 2024 Index rows, Index depth, Index cols, QInt32 alpha, Index strideA, 2025 Index strideB, Index offsetA, Index offsetB) { 2026 EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE); 2027 EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE); 2028 eigen_assert(alpha.value == 1); 2029 eigen_assert(strideA == -1); 2030 eigen_assert(strideB == -1); 2031 eigen_assert(offsetA == 0); 2032 eigen_assert(offsetB == 0); 2033 eigen_assert(rows > 0); 2034 eigen_assert(cols > 0); 2035 
eigen_assert(depth > 0); 2036 eigen_assert(blockA); 2037 eigen_assert(blockB); 2038 2039 // Use alternate function for weird sizes 2040 if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) { 2041 gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs, 2042 ConjugateRhs> gebp; 2043 return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, 2044 offsetA, offsetB); 2045 } 2046 2047 // Create result block 2048 QInt32* blockO = aligned_new<QInt32>(32 * 32); 2049 // Allocating the result block is about 5-10% faster than declaring stack 2050 // space. It is unclear why this is the case. 2051 // ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0); 2052 memset(blockO, 0, 32 * 32 * sizeof(QInt32)); 2053 2054 // Get vectorized pointers 2055 __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO); 2056 const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA); 2057 const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB); 2058 2059 // Loop over blocks of 32 columns 2060 for (Index n = 0; n < cols; n += 32) { 2061 // Reset index into blockA 2062 Index indexL = 0; 2063 // Loop over blocks of 32 rows 2064 for (Index m = 0; m < rows; m += 32) { 2065 // Reset index into blockB 2066 Index indexR = n / 32 * depth; 2067 // Loop over blocks of 8 on depth 2068 for (Index k = 0; k < depth; k += 8) { 2069 // Load inputs 2070 __m256i L_AD0 = blockA_256[indexL++]; 2071 __m256i L_AD8 = blockA_256[indexL++]; 2072 __m256i L_AD16 = blockA_256[indexL++]; 2073 __m256i L_AD24 = blockA_256[indexL++]; 2074 __m256i L_EH0 = blockA_256[indexL++]; 2075 __m256i L_EH8 = blockA_256[indexL++]; 2076 __m256i L_EH16 = blockA_256[indexL++]; 2077 __m256i L_EH24 = blockA_256[indexL++]; 2078 __m256i R_AH0 = blockB_256[indexR++]; 2079 __m256i R_AH4 = blockB_256[indexR++]; 2080 __m256i R_AH8 = blockB_256[indexR++]; 2081 __m256i R_AH12 = blockB_256[indexR++]; 2082 __m256i R_AH16 = blockB_256[indexR++]; 2083 __m256i R_AH20 = blockB_256[indexR++]; 2084 __m256i R_AH24 = blockB_256[indexR++]; 2085 __m256i R_AH28 = blockB_256[indexR++]; 2086 2087 // This constant is used with madd to convert 16 bit to 32 bit 2088 const __m256i ONE = _mm256_set1_epi32(0x00010001); 2089 2090 // Declare variables used in COMPUTE_STEP 2091 __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32; 2092 2093 #define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET) \ 2094 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0); \ 2095 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 2096 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0); \ 2097 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 2098 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 2099 _mm256_store_si256( \ 2100 blockO_256 + 4 * OFFSET, \ 2101 _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32)); \ 2102 \ 2103 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8); \ 2104 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 2105 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8); \ 2106 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 2107 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 2108 _mm256_store_si256( \ 2109 blockO_256 + 4 * OFFSET + 1, \ 2110 _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \ 2111 \ 2112 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16); \ 2113 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 2114 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16); \ 2115 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 2116 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 2117 _mm256_store_si256( \ 2118 blockO_256 + 4 * OFFSET + 2, \ 2119 
_mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \ 2120 \ 2121 P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24); \ 2122 P_32_A = _mm256_madd_epi16(P_16_A, ONE); \ 2123 P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24); \ 2124 P_32_B = _mm256_madd_epi16(P_16_B, ONE); \ 2125 P_32 = _mm256_add_epi32(P_32_A, P_32_B); \ 2126 _mm256_store_si256( \ 2127 blockO_256 + 4 * OFFSET + 3, \ 2128 _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32)); 2129 2130 // Permute and shuffle to copy a single value across the entire vector 2131 // Then compute the multiplication 2132 __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00); 2133 __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2134 __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2135 COMPUTE_STEP(R_AD0, R_EH0, 0); 2136 __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2137 __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2138 COMPUTE_STEP(R_AD1, R_EH1, 1); 2139 R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11); 2140 __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2141 __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2142 COMPUTE_STEP(R_AD2, R_EH2, 2); 2143 __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2144 __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2145 COMPUTE_STEP(R_AD3, R_EH3, 3); 2146 2147 R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00); 2148 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2149 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2150 COMPUTE_STEP(R_AD0, R_EH0, 4); 2151 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2152 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2153 COMPUTE_STEP(R_AD1, R_EH1, 5); 2154 R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11); 2155 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2156 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2157 COMPUTE_STEP(R_AD2, R_EH2, 6); 2158 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2159 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2160 COMPUTE_STEP(R_AD3, R_EH3, 7); 2161 2162 R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00); 2163 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2164 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2165 COMPUTE_STEP(R_AD0, R_EH0, 8); 2166 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2167 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2168 COMPUTE_STEP(R_AD1, R_EH1, 9); 2169 R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11); 2170 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2171 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2172 COMPUTE_STEP(R_AD2, R_EH2, 10); 2173 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2174 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2175 COMPUTE_STEP(R_AD3, R_EH3, 11); 2176 2177 R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00); 2178 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2179 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2180 COMPUTE_STEP(R_AD0, R_EH0, 12); 2181 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2182 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2183 COMPUTE_STEP(R_AD1, R_EH1, 13); 2184 R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11); 2185 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2186 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2187 COMPUTE_STEP(R_AD2, R_EH2, 14); 2188 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2189 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2190 COMPUTE_STEP(R_AD3, R_EH3, 15); 2191 2192 R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00); 2193 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2194 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2195 COMPUTE_STEP(R_AD0, R_EH0, 16); 2196 R_AD1 = 
_mm256_shuffle_epi32(R_AH0_, 0xAA); 2197 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2198 COMPUTE_STEP(R_AD1, R_EH1, 17); 2199 R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11); 2200 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2201 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2202 COMPUTE_STEP(R_AD2, R_EH2, 18); 2203 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2204 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2205 COMPUTE_STEP(R_AD3, R_EH3, 19); 2206 2207 R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00); 2208 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2209 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2210 COMPUTE_STEP(R_AD0, R_EH0, 20); 2211 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2212 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2213 COMPUTE_STEP(R_AD1, R_EH1, 21); 2214 R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11); 2215 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2216 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2217 COMPUTE_STEP(R_AD2, R_EH2, 22); 2218 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2219 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2220 COMPUTE_STEP(R_AD3, R_EH3, 23); 2221 2222 R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00); 2223 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2224 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2225 COMPUTE_STEP(R_AD0, R_EH0, 24); 2226 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2227 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2228 COMPUTE_STEP(R_AD1, R_EH1, 25); 2229 R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11); 2230 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2231 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2232 COMPUTE_STEP(R_AD2, R_EH2, 26); 2233 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2234 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2235 COMPUTE_STEP(R_AD3, R_EH3, 27); 2236 2237 R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00); 2238 R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2239 R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2240 COMPUTE_STEP(R_AD0, R_EH0, 28); 2241 R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2242 R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2243 COMPUTE_STEP(R_AD1, R_EH1, 29); 2244 R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11); 2245 R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00); 2246 R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55); 2247 COMPUTE_STEP(R_AD2, R_EH2, 30); 2248 R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA); 2249 R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF); 2250 COMPUTE_STEP(R_AD3, R_EH3, 31); 2251 2252 #undef COMPUTE_STEP 2253 } 2254 2255 // Transfer the results to the result matrix 2256 Index i = 0; 2257 for (Index j = n; j < n + 32; j++) { 2258 LinearMapper r0 = res.getLinearMapper(m, j); 2259 LinearMapper r1 = res.getLinearMapper(m + 8, j); 2260 LinearMapper r2 = res.getLinearMapper(m + 16, j); 2261 LinearMapper r3 = res.getLinearMapper(m + 24, j); 2262 typedef typename packet_traits<QInt32>::type Packet; 2263 r0.template storePacket<Packet>( 2264 0, _mm256_add_epi32(blockO_256[i++], 2265 r0.template loadPacket<Packet>(0))); 2266 r1.template storePacket<Packet>( 2267 0, _mm256_add_epi32(blockO_256[i++], 2268 r1.template loadPacket<Packet>(0))); 2269 r2.template storePacket<Packet>( 2270 0, _mm256_add_epi32(blockO_256[i++], 2271 r2.template loadPacket<Packet>(0))); 2272 r3.template storePacket<Packet>( 2273 0, _mm256_add_epi32(blockO_256[i++], 2274 r3.template loadPacket<Packet>(0))); 2275 } 2276 2277 // Zero the result block so it can be reused 2278 memset(blockO, 0, 32 * 32 * sizeof(QInt32)); 2279 } 2280 } 2281 aligned_delete(blockO, 32 * 32); 2282 } 
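// A note on the COMPUTE_STEP pattern used in the kernels above (illustrative
// sketch, not part of the build): _mm256_maddubs_epi16 multiplies the
// unsigned RHS bytes with the signed LHS bytes and sums adjacent pairs into
// saturating 16-bit values, and _mm256_madd_epi16 with the 0x00010001
// constant then sums adjacent 16-bit pairs into 32-bit values. Each 32-bit
// accumulator lane therefore receives the dot product of four consecutive
// depth values, and the AD/EH halves together cover the eight depth values
// consumed per k-step. Roughly, per output element:
//
//   int32_t acc = 0;
//   for (int d = 0; d < 8; ++d)
//     acc += int32_t(uint8_t(rhs_byte[d])) * int32_t(int8_t(lhs_byte[d]));
//
// where rhs_byte / lhs_byte stand for the packed QUInt8 / QInt8 inputs of
// one (row, column) pair; the names are placeholders for illustration.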

#endif  // EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT

}  // namespace internal
}  // namespace Eigen

#endif  // CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_