1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Benoit Steiner <[email protected]>
5 // Copyright (C) 2015 Matthew Sarett <[email protected]>
6 // Copyright (C) 2016 Nishant Patil <[email protected]>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_
13 #define CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_
14 
15 namespace Eigen {
16 namespace internal {
17 
18 // AVX2 optimized implementation of Mat-Mat product.
19 // LHS is encoded using signed 16-bit integers.
20 // RHS is encoded using signed 16-bit integers.
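//
// A minimal usage sketch (for illustration only, assuming the fixed-point
// types QInt16/QInt32 from this module are available): with
// EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT defined, an ordinary matrix
// product picks up the kernels below and accumulates into 32 bits, e.g.
//
//   Eigen::Matrix<Eigen::QInt16, Eigen::Dynamic, Eigen::Dynamic> lhs(64, 64);
//   Eigen::Matrix<Eigen::QInt16, Eigen::Dynamic, Eigen::Dynamic> rhs(64, 64);
//   Eigen::Matrix<Eigen::QInt32, Eigen::Dynamic, Eigen::Dynamic> res = lhs * rhs;
//
// The fully optimized path further assumes dimensions that are multiples of 16.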
21 #ifdef EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT
22 
23 // Define quantized traits
24 template <bool _ConjLhs, bool _ConjRhs>
25 class gebp_traits<QInt16, QInt16, _ConjLhs, _ConjRhs> {
26  public:
27   typedef QInt16 LhsScalar;
28   typedef QInt16 RhsScalar;
29   typedef QInt32 ResScalar;
30 
31   typedef typename packet_traits<LhsScalar>::type LhsPacket;
32   typedef LhsPacket LhsPacket4Packing;
33 
34   enum {
35     // Define register blocking scheme.
36     nr = 16,
37     mr = 16,
38     kr = 4,
39     // Ignore progress tracking per loop iteration.
40     LhsProgress = -1,
41     RhsProgress = -1
42   };
43 };
44 
45 // Specialized blocking for quantized implementations.
46 // Used by TensorContractionThreadPool, inputs must have dimensions that are
47 // multiples of 16.
48 template <typename Index, int ShardingType>
49 class TensorContractionBlocking<QInt16, QInt16, QInt16, Index, ShardingType> {
50  public:
51   TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
52       : kc_(((k + 15) / 16) * 16),
53         mc_(((m + 15) / 16) * 16),
54         nc_(((n + 15) / 16) * 16) {
55     eigen_assert(mc_ % 16 == 0);
56     eigen_assert(kc_ % 16 == 0);
57     if (!k || !m || !n) {
58       return;
59     }
60 
61     if (ShardingType == ShardByCol) {
62       eigen_assert(nc_ % 16 == 0);
63       nc_ = (((nc_ / num_threads) + 15) / 16) * 16;
64     } else {
65       eigen_assert(nc_ % 16 == 0);
66       mc_ = (((mc_ / num_threads) + 15) / 16) * 16;
67     }
68   }
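  // For example, k = 100, m = 50, n = 70 with num_threads = 4 and ShardByCol
  // gives kc() = 112, mc() = 64 and nc() = (((80 / 4) + 15) / 16) * 16 = 32.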
69 
70   EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
71   EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
72   EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
73 
74  private:
75   Index kc_;
76   Index mc_;
77   Index nc_;
78 };
79 
80 // Specialized blocking for quantized implementations.
81 // Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to
82 // multiples of 16.
83 template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
84 class gemm_blocking_space<ColMajor, QInt16, QInt16, MaxRows, MaxCols, MaxDepth,
85                           KcFactor, false>
86     : public level3_blocking<QInt16, QInt16> {
87   DenseIndex m_sizeA;
88   DenseIndex m_sizeB;
89 
90  public:
91   gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
92                       DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
93     this->m_mc = ((rows + 15) / 16) * 16;
94     this->m_nc = ((cols + 15) / 16) * 16;
95     this->m_kc = ((depth + 15) / 16) * 16;
96     m_sizeA = this->m_mc * this->m_kc;
97     m_sizeB = this->m_kc * this->m_nc;
98   }
99   void allocateA() {
100     if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt16>(m_sizeA);
101   }
102   void allocateB() {
103     if (this->m_blockB == 0) this->m_blockB = aligned_new<QInt16>(m_sizeB);
104   }
105   void allocateAll() {
106     allocateA();
107     allocateB();
108   }
109   ~gemm_blocking_space() {
110     aligned_delete(this->m_blockA, m_sizeA);
111     aligned_delete(this->m_blockB, m_sizeB);
112   }
113 };
114 
115 // Below are the fully optimized versions, which are correct only for sizes
116 // that are multiples of 16.  Keeping these implementations separate yields
117 // roughly a 10% performance benefit.
118 
119 // Arrange a block of the left input matrix in contiguous memory.
120 //
121 // Given column major input (A0 beside A1 in memory):
122 // A0 B0 C0 D0 E0 F0 G0 H0 ...
123 // A1 B1 C1 D1 E1 F1 G1 H1 ...
124 // A2 B2 C2 D2 E2 F2 G2 H2 ...
125 // A3 B3 C3 D3 E3 F3 G3 H3 ...
126 // A4 B4 C4 D4 E4 F4 G4 H4 ...
127 // A5 B5 C5 D5 E5 F5 G5 H5 ...
128 // A6 B6 C6 D6 E6 F6 G6 H6 ...
129 // A7 B7 C7 D7 E7 F7 G7 H7 ...
130 // A8 ...
131 // ...
132 //
133 // Packing with m = 8 yields row major output (A0 beside B0 in memory):
134 // A0 B0
135 // A1 B1
136 // A2 B2
137 // A3 B3
138 // A4 B4
139 // A5 B5
140 // A6 B6
141 // A7 B7
142 // ...
143 //
144 // The purpose is to collect m rows of size k.  Two elements of the same
145 // row are arranged contiguously because madd performs an adjacent addition
146 // in the kernel.
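//
// Concretely (a sketch derived from the intrinsics below), one k-step packs a
// 16 x 4 tile starting at row m and depth k into four vectors:
//   blockA[ 0..15]: lhs(m+0,k), lhs(m+0,k+1), lhs(m+1,k), lhs(m+1,k+1), ...,
//                   lhs(m+7,k), lhs(m+7,k+1)
//   blockA[16..31]: the same 8 rows for depths k+2 and k+3
//   blockA[32..47]: rows m+8..m+15 for depths k and k+1
//   blockA[48..63]: rows m+8..m+15 for depths k+2 and k+3
// so each __m256i holds 8 rows times 2 adjacent depth values, ready for madd.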
147 
148 template <typename Index, typename DataMapper, int Pack1, int Pack2,
149           bool Conjugate, bool PanelMode>
150 struct gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2, QInt16, ColMajor,
151                      Conjugate, PanelMode> {
152   EIGEN_DONT_INLINE void operator()(QInt16* blockA, const DataMapper& lhs,
153                                     Index depth, Index rows, Index stride = 0,
154                                     Index offset = 0);
155 };
156 
157 template <typename Index, typename DataMapper, int Pack1, int Pack2,
158           bool Conjugate, bool PanelMode>
159 EIGEN_DONT_INLINE void gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2,
160                                      QInt16, ColMajor, Conjugate, PanelMode>::
161 operator()(QInt16* blockA, const DataMapper& lhs, Index depth, Index rows,
162            Index stride, Index offset) {
163   eigen_assert(stride == 0);
164   eigen_assert(offset == 0);
165 
166   typedef typename packet_traits<QInt16>::type Packet;
167 
168   // Use alternate function for sizes that are not multiples of 16
169   if (rows % 16 != 0 || depth % 16 != 0) {
170     assert(false &&
171            "only depths and rows that are a multiple of 16 are currently "
172            "supported");
173     // gemm_pack_lhs_any<QInt16, Index, DataMapper, Pack1, Pack2, ColMajor,
174     // Conjugate, PanelMode> lhs_pack;
175     // return lhs_pack(blockA, lhs, depth, rows, stride, offset);
176   }
177 
178   // Get vector pointer
179   __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);
180 
181   // Pack rows in sets of 16
182   for (Index m = 0; m < rows; m += 16) {
183     // Pack depth in sets of 4
184     for (Index k = 0; k < depth; k += 4) {
185       // Load vectors
186       __m256i L_A = lhs.template loadPacket<Packet>(m, k);
187       __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);
188       __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
189       __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);
190 
191       // Rearrange the inputs as required by the kernel
192       __m256i L_AB0_AB7 = _mm256_unpacklo_epi16(L_A, L_B);
193       __m256i L_AB8_AB15 = _mm256_unpackhi_epi16(L_A, L_B);
194       __m256i L_CD0_CD7 = _mm256_unpacklo_epi16(L_C, L_D);
195       __m256i L_CD8_CD15 = _mm256_unpackhi_epi16(L_C, L_D);
196 
197       __m256i L_AD0 = _mm256_permute2x128_si256(L_AB0_AB7, L_AB8_AB15, 0x20);
198       _mm256_store_si256(blockA_256++, L_AD0);
199       __m256i L_AD8 = _mm256_permute2x128_si256(L_CD0_CD7, L_CD8_CD15, 0x20);
200       _mm256_store_si256(blockA_256++, L_AD8);
201       __m256i L_AD16 = _mm256_permute2x128_si256(L_AB0_AB7, L_AB8_AB15, 0x31);
202       _mm256_store_si256(blockA_256++, L_AD16);
203       __m256i L_AD24 = _mm256_permute2x128_si256(L_CD0_CD7, L_CD8_CD15, 0x31);
204       _mm256_store_si256(blockA_256++, L_AD24);
205     }
206   }
207 }
208 
209 // Arrange a block of the right input matrix in contiguous memory.
210 //
211 // Given column major input (A0 beside A1 in memory):
212 // A0 B0 C0 D0 E0 F0 G0 H0 ...
213 // A1 B1 C1 D1 E1 F1 G1 H1 ...
214 // A2 B2 C2 D2 E2 F2 G2 H2 ...
215 // A3 B3 C3 D3 E3 F3 G3 H3 ...
216 // A4 B4 C4 D4 E4 F4 G4 H4 ...
217 // A5 B5 C5 D5 E5 F5 G5 H5 ...
218 // A6 B6 C6 D6 E6 F6 G6 H6 ...
219 // A7 B7 C7 D7 E7 F7 G7 H7 ...
220 // A8 ...
221 // ...
222 // Packing yields row major output (A0 beside A1 in memory):
223 // A0 A1 A2 A3 A4 A5 A6 A7
224 // B0 B1 B2 B3 B4 B5 B6 B7
225 // ...
226 //
227 // The purpose is to collect n cols of size k.  At least two elements of the
228 // same col are arranged contiguously because madd performs an adjacent
229 // addition in the kernel; we can save further work by leaving runs of 4
230 // adjacent elements per col because kr = 4.
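//
// Concretely (a sketch derived from the PACK_STEP macro below), for the tile
// at depth k and col n the vector stored at blockB_256[0] holds
// rhs(k..k+3, n), rhs(k..k+3, n+1), rhs(k..k+3, n+2), rhs(k..k+3, n+3); the
// vectors at offsets +4, +8 and +12 hold the same four cols for depths
// k+4..k+7, k+8..k+11 and k+12..k+15, so each __m256i carries 4 adjacent
// depth values for each of 4 cols.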
233 template <typename Index, typename DataMapper, int nr, bool Conjugate,
234           bool PanelMode>
235 struct gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor, Conjugate,
236                      PanelMode> {
237   EIGEN_DONT_INLINE void operator()(QInt16* blockB, const DataMapper& rhs,
238                                     Index depth, Index cols, Index stride = 0,
239                                     Index offset = 0);
240 };
241 
242 template <typename Index, typename DataMapper, int nr, bool Conjugate,
243           bool PanelMode>
244 EIGEN_DONT_INLINE void gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor,
245                                      Conjugate, PanelMode>::
246 operator()(QInt16* blockB, const DataMapper& rhs, Index depth, Index cols,
247            Index stride, Index offset) {
248   eigen_assert(stride == 0);
249   eigen_assert(offset == 0);
250 
251   typedef typename packet_traits<QInt16>::type Packet;
252 
253   // Use alternate function for sizes that are not multiples of 16
254   if (cols % 16 != 0 || depth % 16 != 0) {
255     assert(false &&
256            "only depths and cols that are a multiple of 16 are currently "
257            "supported");
258     // gemm_pack_rhs_any<QInt16, Index, DataMapper, nr, ColMajor, Conjugate,
259     // PanelMode> rhs_pack;
260     // return rhs_pack(blockB, rhs, depth, cols, stride, offset);
261   }
262 
263   // Get vector pointer
264   __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);
265 
266   // Perform a step of the packing for 4 columns
267   __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_4, R_AD_8, R_AD_12;
268 #define PACK_STEP                                            \
269   R_AB_L = _mm256_unpacklo_epi64(R_A, R_B);                  \
270   R_CD_L = _mm256_unpacklo_epi64(R_C, R_D);                  \
271   R_AB_H = _mm256_unpackhi_epi64(R_A, R_B);                  \
272   R_CD_H = _mm256_unpackhi_epi64(R_C, R_D);                  \
273   R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20);  \
274   R_AD_8 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31);  \
275   R_AD_4 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20);  \
276   R_AD_12 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \
277   _mm256_store_si256(blockB_256, R_AD_0);                    \
278   _mm256_store_si256(blockB_256 + 4, R_AD_4);                \
279   _mm256_store_si256(blockB_256 + 8, R_AD_8);                \
280   _mm256_store_si256(blockB_256 + 12, R_AD_12);              \
281   blockB_256++;
282 
283   // Pack cols in sets of 16
284   for (Index n = 0; n < cols; n += 16) {
285     // Pack depth in sets of 16
286     for (Index k = 0; k < depth; k += 16) {
287       __m256i R_A = rhs.template loadPacket<Packet>(k, n);
288       __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
289       __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
290       __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
291       PACK_STEP;
292 
293       R_A = rhs.template loadPacket<Packet>(k, n + 4);
294       R_B = rhs.template loadPacket<Packet>(k, n + 5);
295       R_C = rhs.template loadPacket<Packet>(k, n + 6);
296       R_D = rhs.template loadPacket<Packet>(k, n + 7);
297       PACK_STEP;
298 
299       R_A = rhs.template loadPacket<Packet>(k, n + 8);
300       R_B = rhs.template loadPacket<Packet>(k, n + 9);
301       R_C = rhs.template loadPacket<Packet>(k, n + 10);
302       R_D = rhs.template loadPacket<Packet>(k, n + 11);
303       PACK_STEP;
304 
305       R_A = rhs.template loadPacket<Packet>(k, n + 12);
306       R_B = rhs.template loadPacket<Packet>(k, n + 13);
307       R_C = rhs.template loadPacket<Packet>(k, n + 14);
308       R_D = rhs.template loadPacket<Packet>(k, n + 15);
309       PACK_STEP;
310 
311       blockB_256 += 12;
312     }
313   }
314 #undef PACK_STEP
315 }
316 
317 // Perform the actual multiplication on packed inputs
318 template <typename Index, typename DataMapper, int mr, int nr,
319           bool ConjugateLhs, bool ConjugateRhs>
320 struct gebp_kernel<QInt16, QInt16, Index, DataMapper, mr, nr, ConjugateLhs,
321                    ConjugateRhs> {
322   typedef typename DataMapper::LinearMapper LinearMapper;
323 
324   EIGEN_DONT_INLINE
325   void operator()(const DataMapper& res, const QInt16* blockA,
326                   const QInt16* blockB, Index rows, Index depth, Index cols,
327                   QInt32 alpha, Index strideA = -1, Index strideB = -1,
328                   Index offsetA = 0, Index offsetB = 0);
329 };
330 
331 template <typename Index, typename DataMapper, int mr, int nr,
332           bool ConjugateLhs, bool ConjugateRhs>
333 EIGEN_DONT_INLINE void gebp_kernel<QInt16, QInt16, Index, DataMapper, mr, nr,
334                                    ConjugateLhs, ConjugateRhs>::
335 operator()(const DataMapper& res, const QInt16* blockA, const QInt16* blockB,
336            Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
337            Index strideB, Index offsetA, Index offsetB) {
338   EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
339   EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
340   eigen_assert(alpha.value == 1);
341   eigen_assert(strideA == -1);
342   eigen_assert(strideB == -1);
343   eigen_assert(offsetA == 0);
344   eigen_assert(offsetB == 0);
345   eigen_assert(rows > 0);
346   eigen_assert(cols > 0);
347   eigen_assert(depth > 0);
348   eigen_assert(blockA);
349   eigen_assert(blockB);
350 
351   // Use alternate function for sizes that are not multiples of 16
352   if (rows % 16 != 0 || cols % 16 != 0 || depth % 16 != 0) {
353     assert(false &&
354            "only depths, cols and rows that are a multiple of 16 are currently "
355            "supported");
356     // gebp_kernel_any<QInt16, QInt16, Index, DataMapper, mr, nr, ConjugateLhs,
357     // ConjugateRhs> gebp;
358     // return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA,
359     // strideB, offsetA, offsetB);
360   }
361 
362   // Create result block
363   QInt32* blockO = aligned_new<QInt32>(16 * 16);
364   memset(blockO, 0, 16 * 16 * sizeof(QInt32));
365 
366   // Get vectorized pointers
367   __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
368   const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
369   const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);
370 
371   // Loop over blocks of 16 columns
372   for (Index n = 0; n < cols; n += 16) {
373     // Reset index into blockA
374     Index indexL = 0;
375     // Loop over blocks of 16 rows
376     for (Index m = 0; m < rows; m += 16) {
377       // Reset index into blockB
378       Index indexR = n / 16 * depth;
379       // Loop over blocks of 4 on depth
380       for (Index k = 0; k < depth; k += 4) {
381         // Load inputs
382         __m256i L_AD0 = blockA_256[indexL++];
383         __m256i L_AD8 = blockA_256[indexL++];
384         __m256i L_EH0 = blockA_256[indexL++];
385         __m256i L_EH8 = blockA_256[indexL++];
386 
387         __m256i R_AH0 = blockB_256[indexR++];
388         __m256i R_AH4 = blockB_256[indexR++];
389         __m256i R_AH8 = blockB_256[indexR++];
390         __m256i R_AH12 = blockB_256[indexR++];
391 
392         // Declare variables used in COMPUTE_STEP
393         __m256i P_32_A, P_32_B, P_32;
394 
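        // A note on COMPUTE_STEP (derived from the packed layouts above):
        // R_INPUT_A broadcasts two adjacent RHS depth values of one result
        // column (index n + OFFSET) and R_INPUT_B the next two, so each madd
        // adds adjacent 16-bit products and one invocation accumulates a
        // 4-deep partial dot product into blockO_256[2 * OFFSET] (rows
        // m..m+7) and blockO_256[2 * OFFSET + 1] (rows m+8..m+15) of that
        // column.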
395 #define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                         \
396   P_32_A = _mm256_madd_epi16(R_INPUT_A, L_AD0);                            \
397   P_32_B = _mm256_madd_epi16(R_INPUT_B, L_AD8);                            \
398   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                 \
399   _mm256_store_si256(                                                      \
400       blockO_256 + 2 * OFFSET,                                             \
401       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET), P_32)); \
402                                                                            \
403   P_32_A = _mm256_madd_epi16(R_INPUT_A, L_EH0);                            \
404   P_32_B = _mm256_madd_epi16(R_INPUT_B, L_EH8);                            \
405   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                 \
406   _mm256_store_si256(                                                      \
407       blockO_256 + 2 * OFFSET + 1,                                         \
408       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET + 1), P_32));
409 
410         // Permute and shuffle to copy a single 32-bit element (two QInt16
411         // values) across the entire vector, then compute the multiplication
412         // Replicate lower 128-bits of R_AH0 across both lanes
413         __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
414         // Copy first two elements of R_AH0 across entire vector
415         __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
416         // Copy second two elements of R_AH0 across entire vector
417         __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
418 
419         COMPUTE_STEP(R_AD0, R_EH0, 0);
420         __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
421         __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
422         COMPUTE_STEP(R_AD1, R_EH1, 1);
423 
424         // Replicate upper 128-bits of R_AH0 across both lanes
425         R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
426         __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
427         __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
428         COMPUTE_STEP(R_AD2, R_EH2, 2);
429         __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
430         __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
431         COMPUTE_STEP(R_AD3, R_EH3, 3);
432 
433         R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
434         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
435         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
436         COMPUTE_STEP(R_AD0, R_EH0, 4);
437         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
438         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
439         COMPUTE_STEP(R_AD1, R_EH1, 5);
440         R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
441         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
442         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
443         COMPUTE_STEP(R_AD2, R_EH2, 6);
444         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
445         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
446         COMPUTE_STEP(R_AD3, R_EH3, 7);
447 
448         R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
449         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
450         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
451         COMPUTE_STEP(R_AD0, R_EH0, 8);
452         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
453         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
454         COMPUTE_STEP(R_AD1, R_EH1, 9);
455         R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
456         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
457         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
458         COMPUTE_STEP(R_AD2, R_EH2, 10);
459         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
460         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
461         COMPUTE_STEP(R_AD3, R_EH3, 11);
462 
463         R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
464         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
465         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
466         COMPUTE_STEP(R_AD0, R_EH0, 12);
467         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
468         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
469         COMPUTE_STEP(R_AD1, R_EH1, 13);
470         R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
471         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
472         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
473         COMPUTE_STEP(R_AD2, R_EH2, 14);
474         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
475         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
476         COMPUTE_STEP(R_AD3, R_EH3, 15);
477 
478 #undef COMPUTE_STEP
479       }
480 
481       // Transfer the results to the result matrix
482       Index i = 0;
483       for (Index j = n; j < n + 16; j++) {
484         LinearMapper r0 = res.getLinearMapper(m, j);
485         LinearMapper r1 = res.getLinearMapper(m + 8, j);
486         typedef typename packet_traits<QInt32>::type Packet;
487         r0.template storePacket<Packet>(
488             0, _mm256_add_epi32(blockO_256[i++],
489                                 r0.template loadPacket<Packet>(0)));
490         r1.template storePacket<Packet>(
491             0, _mm256_add_epi32(blockO_256[i++],
492                                 r1.template loadPacket<Packet>(0)));
493       }
494 
495       // Zero the result block so it can be reused
496       memset(blockO, 0, 16 * 16 * sizeof(QInt32));
497     }
498   }
499   aligned_delete(blockO, 16 * 16);
500 }
501 
502 #endif
503 
504 // AVX2 optimized implementation of Mat-Mat product.
505 // LHS is encoded using signed 8-bit integers.
506 // RHS is encoded using unsigned 8-bit integers.
507 #ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
508 
509 // Define quantized traits
510 template <bool _ConjLhs, bool _ConjRhs>
511 class gebp_traits<QInt8, QUInt8, _ConjLhs, _ConjRhs> {
512  public:
513   typedef QInt8 LhsScalar;
514   typedef QUInt8 RhsScalar;
515   typedef QInt32 ResScalar;
516 
517   typedef typename packet_traits<LhsScalar>::type LhsPacket;
518   typedef LhsPacket LhsPacket4Packing;
519 
520   enum {
521     // Define register blocking scheme.
522     nr = 32,
523     mr = 32,
524     kr = 8,
525     // Ignore progress tracking per loop iteration.
526     LhsProgress = -1,
527     RhsProgress = -1
528   };
529 };
530 
531 // Specialized blocking for quantized implementations.
532 // Used by TensorContractionThreadPool, inputs must have dimensions that are
533 // multiples of 32.
534 template <typename ResScalar, typename Index, typename LeftTensor,
535           typename left_nocontract_t, typename left_contract_t,
536           bool left_inner_dim_contiguous, bool left_inner_dim_reordered,
537           int LeftAlignment, typename RightTensor, typename right_nocontract_t,
538           typename right_contract_t, bool right_inner_dim_contiguous,
539           bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
540 class TensorContractionBlocking<
541     ResScalar,
542     TensorContractionInputMapper<
543         QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32,
544         left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>,
545     TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor,
546                                  right_nocontract_t, right_contract_t, 32,
547                                  right_inner_dim_contiguous,
548                                  right_inner_dim_reordered, RightAlignment>,
549     Index, ShardingType> {
550  public:
551   typedef QInt8 LhsScalar;
552   typedef QUInt8 RhsScalar;
553 
554   TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
555       : kc_(k), mc_(m), nc_(n) {
556     eigen_assert(m % 32 == 0);
557     eigen_assert(k % 32 == 0);
558     if (!k || !m || !n) {
559       return;
560     }
561 
562     if (ShardingType == ShardByCol) {
563       eigen_assert(n % 32 == 0);
564       nc_ = (((n / num_threads) + 31) / 32) * 32;
565     } else {
566       eigen_assert(n % 32 == 0 || n == 1);
567       // Special case to avoid breaking the unimplemented matrix-vector case
568       if (n == 1) {
569         nc_ = 32;
570       }
571       mc_ = (((m / num_threads) + 31) / 32) * 32;
572     }
573   }
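  // For example (using the constructor above), k = 64, m = 128, n = 1 with
  // num_threads = 2 and row sharding keeps kc() = 64, sets nc() = 32 via the
  // matrix-vector special case, and shards the rows:
  // mc() = (((128 / 2) + 31) / 32) * 32 = 64.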
574 
575   EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
576   EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
577   EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
578 
579  private:
580   Index kc_;
581   Index mc_;
582   Index nc_;
583 };
584 
585 // Specialized blocking for quantized implementations.
586 // Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to
587 // multiples of 32.
588 template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
589 class gemm_blocking_space<ColMajor, QInt8, QInt8, MaxRows, MaxCols, MaxDepth,
590                           KcFactor, false>
591     : public level3_blocking<QInt8, QInt8> {
592   DenseIndex m_sizeA;
593   DenseIndex m_sizeB;
594 
595  public:
596   gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
597                       DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
598     this->m_mc = ((rows + 31) / 32) * 32;
599     this->m_nc = ((cols + 31) / 32) * 32;
600     this->m_kc = ((depth + 31) / 32) * 32;
601     m_sizeA = this->m_mc * this->m_kc;
602     m_sizeB = this->m_kc * this->m_nc;
603   }
604   void allocateA() {
605     if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA);
606   }
607   void allocateB() {
608     if (this->m_blockB == 0) this->m_blockB = aligned_new<QInt8>(m_sizeB);
609   }
610   void allocateAll() {
611     allocateA();
612     allocateB();
613   }
614   ~gemm_blocking_space() {
615     aligned_delete(this->m_blockA, m_sizeA);
616     aligned_delete(this->m_blockB, m_sizeB);
617   }
618 };
619 
620 template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
621 class gemm_blocking_space<ColMajor, QInt8, QUInt8, MaxRows, MaxCols, MaxDepth,
622                           KcFactor, false>
623     : public level3_blocking<QInt8, QUInt8> {
624   DenseIndex m_sizeA;
625   DenseIndex m_sizeB;
626 
627  public:
628   gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
629                       DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
630     this->m_mc = ((rows + 31) / 32) * 32;
631     this->m_nc = ((cols + 31) / 32) * 32;
632     this->m_kc = ((depth + 31) / 32) * 32;
633     m_sizeA = this->m_mc * this->m_kc;
634     m_sizeB = this->m_kc * this->m_nc;
635   }
636   void allocateA() {
637     if (this->m_blockA == 0) this->m_blockA = aligned_new<QInt8>(m_sizeA);
638   }
639   void allocateB() {
640     if (this->m_blockB == 0) this->m_blockB = aligned_new<QUInt8>(m_sizeB);
641   }
642   void allocateAll() {
643     allocateA();
644     allocateB();
645   }
646   ~gemm_blocking_space() {
647     aligned_delete(this->m_blockA, m_sizeA);
648     aligned_delete(this->m_blockB, m_sizeB);
649   }
650 };
651 
652 // Alternate templates for any input sizes
653 template <typename Scalar, typename Index, typename DataMapper, int Pack1,
654           int Pack2, int StorageOrder, bool Conjugate = false,
655           bool PanelMode = false>
656 struct gemm_pack_lhs_any;
657 template <typename Index, typename DataMapper, int Pack1, int Pack2,
658           bool Conjugate, bool PanelMode>
659 struct gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
660                          Conjugate, PanelMode> {
661   EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs,
662                                     Index depth, Index rows, Index stride = 0,
663                                     Index offset = 0);
664 };
665 
666 template <typename Scalar, typename Index, typename DataMapper, int nr,
667           int StorageOrder, bool Conjugate = false, bool PanelMode = false>
668 struct gemm_pack_rhs_any;
669 template <typename Index, typename DataMapper, int nr, bool Conjugate,
670           bool PanelMode>
671 struct gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
672                          PanelMode> {
673   EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs,
674                                     Index depth, Index cols, Index stride = 0,
675                                     Index offset = 0);
676 };
677 
678 template <typename LhsScalar, typename RhsScalar, typename Index,
679           typename DataMapper, int mr, int nr, bool ConjugateLhs = false,
680           bool ConjugateRhs = false>
681 struct gebp_kernel_any;
682 template <typename Index, typename DataMapper, int mr, int nr,
683           bool ConjugateLhs, bool ConjugateRhs>
684 struct gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
685                        ConjugateRhs> {
686   typedef typename DataMapper::LinearMapper LinearMapper;
687 
688   EIGEN_DONT_INLINE
689   void operator()(const DataMapper& res, const QInt8* blockA,
690                   const QUInt8* blockB, Index rows, Index depth, Index cols,
691                   QInt32 alpha, Index strideA = -1, Index strideB = -1,
692                   Index offsetA = 0, Index offsetB = 0);
693 };
694 
695 // Alternate implementations for any input sizes
696 template <typename Index, typename DataMapper, int Pack1, int Pack2,
697           bool Conjugate, bool PanelMode>
698 EIGEN_DONT_INLINE void gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2,
699                                          ColMajor, Conjugate, PanelMode>::
700 operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
701            Index stride, Index offset) {
702   eigen_assert(stride == 0);
703   eigen_assert(offset == 0);
704 
705   typedef typename packet_traits<QInt8>::type Packet;
706 
707   // Get vector pointer
708   __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);
709 
710   // Get even multiples of the dimensions
711   Index rows_32 = (rows / 32) * 32;
712   Index depth_8 = (depth / 8) * 8;
713 
714   // Get padding for when depth is not a multiple of 32
715   int padding = 0;
716   if (depth % 32 != 0) {
717     int depth_32 = (depth / 32) * 32;
718     int extra_depth = depth - depth_32;
719     int extra_depth_8 = ((extra_depth + 7) / 8) * 8;
720     padding = 32 - extra_depth_8;
721   }
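  // For example, depth = 44 gives depth_8 = 40, extra_depth_8 = 16 and
  // padding = 32 - 16 = 16; together with the 48 vectors actually written,
  // each 32-row panel then advances by 64 __m256i, i.e. the depth rounded up
  // to a multiple of 32.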
722 
723   // Pack rows in sets of 32
724   for (Index m = 0; m < rows_32; m += 32) {
725     // Pack depth in sets of 8
726     for (Index k = 0; k < depth_8; k += 8) {
727       // Load vectors
728       __m256i L_A = lhs.template loadPacket<Packet>(m, k);
729       __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);
730 
731       // Interleave 8-bit elements
732       __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
733       __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
734 
735       __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
736       __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);
737       __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
738       __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
739 
740       // Interleave 16-bit elements
741       __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
742       __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
743 
744       // Use permute before we store to cross 128-bit lanes
745       __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
746       _mm256_store_si256(blockA_256++, L_AD0);
747 
748       // Complete packing for 32 x 8 block
749       __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
750       __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
751       __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
752       __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
753       _mm256_store_si256(blockA_256++, L_AD8);
754       _mm256_store_si256(blockA_256++, L_AD16);
755       __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
756       _mm256_store_si256(blockA_256++, L_AD24);
757       __m256i L_E = lhs.template loadPacket<Packet>(m, k + 4);
758       __m256i L_F = lhs.template loadPacket<Packet>(m, k + 5);
759       __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
760       __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
761       __m256i L_G = lhs.template loadPacket<Packet>(m, k + 6);
762       __m256i L_H = lhs.template loadPacket<Packet>(m, k + 7);
763       __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
764       __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
765       __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
766       __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
767       __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
768       _mm256_store_si256(blockA_256++, L_EH0);
769       __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
770       __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
771       __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
772       __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
773       _mm256_store_si256(blockA_256++, L_EH8);
774       _mm256_store_si256(blockA_256++, L_EH16);
775       __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
776       _mm256_store_si256(blockA_256++, L_EH24);
777     }
778 
779     // Finish the k dimension, padding with zeros
780     if (depth_8 < depth) {
781       __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H;
782       switch (depth - depth_8) {
783         case 1:
784           L_A = lhs.template loadPacket<Packet>(m, depth_8);
785           L_B = _mm256_setzero_si256();
786           L_C = _mm256_setzero_si256();
787           L_D = _mm256_setzero_si256();
788           L_E = _mm256_setzero_si256();
789           L_F = _mm256_setzero_si256();
790           L_G = _mm256_setzero_si256();
791           L_H = _mm256_setzero_si256();
792           break;
793         case 2:
794           L_A = lhs.template loadPacket<Packet>(m, depth_8);
795           L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
796           L_C = _mm256_setzero_si256();
797           L_D = _mm256_setzero_si256();
798           L_E = _mm256_setzero_si256();
799           L_F = _mm256_setzero_si256();
800           L_G = _mm256_setzero_si256();
801           L_H = _mm256_setzero_si256();
802           break;
803         case 3:
804           L_A = lhs.template loadPacket<Packet>(m, depth_8);
805           L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
806           L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
807           L_D = _mm256_setzero_si256();
808           L_E = _mm256_setzero_si256();
809           L_F = _mm256_setzero_si256();
810           L_G = _mm256_setzero_si256();
811           L_H = _mm256_setzero_si256();
812           break;
813         case 4:
814           L_A = lhs.template loadPacket<Packet>(m, depth_8);
815           L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
816           L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
817           L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
818           L_E = _mm256_setzero_si256();
819           L_F = _mm256_setzero_si256();
820           L_G = _mm256_setzero_si256();
821           L_H = _mm256_setzero_si256();
822           break;
823         case 5:
824           L_A = lhs.template loadPacket<Packet>(m, depth_8);
825           L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
826           L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
827           L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
828           L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
829           L_F = _mm256_setzero_si256();
830           L_G = _mm256_setzero_si256();
831           L_H = _mm256_setzero_si256();
832           break;
833         case 6:
834           L_A = lhs.template loadPacket<Packet>(m, depth_8);
835           L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
836           L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
837           L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
838           L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
839           L_F = lhs.template loadPacket<Packet>(m, depth_8 + 5);
840           L_G = _mm256_setzero_si256();
841           L_H = _mm256_setzero_si256();
842           break;
843         case 7:
844           L_A = lhs.template loadPacket<Packet>(m, depth_8);
845           L_B = lhs.template loadPacket<Packet>(m, depth_8 + 1);
846           L_C = lhs.template loadPacket<Packet>(m, depth_8 + 2);
847           L_D = lhs.template loadPacket<Packet>(m, depth_8 + 3);
848           L_E = lhs.template loadPacket<Packet>(m, depth_8 + 4);
849           L_F = lhs.template loadPacket<Packet>(m, depth_8 + 5);
850           L_G = lhs.template loadPacket<Packet>(m, depth_8 + 6);
851           L_H = _mm256_setzero_si256();
852           break;
853       }
854 
855       // Interleave 8-bit elements
856       __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
857       __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
858 
859       __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
860       __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
861 
862       // Interleave 16-bit elements
863       __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
864       __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
865 
866       // Use permute before we store to cross 128-bit lanes
867       __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
868       _mm256_store_si256(blockA_256++, L_AD0);
869 
870       // Complete packing
871       __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
872       __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
873       __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
874       __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
875       _mm256_store_si256(blockA_256++, L_AD8);
876       _mm256_store_si256(blockA_256++, L_AD16);
877       __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
878       _mm256_store_si256(blockA_256++, L_AD24);
879       __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
880       __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
881       __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
882       __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
883       __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
884       __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
885       __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
886       _mm256_store_si256(blockA_256++, L_EH0);
887       __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
888       __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
889       __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
890       __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
891       _mm256_store_si256(blockA_256++, L_EH8);
892       _mm256_store_si256(blockA_256++, L_EH16);
893       __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
894       _mm256_store_si256(blockA_256++, L_EH24);
895     }
896     blockA_256 += padding;
897   }
898 
899   // Finish the m dimension, padding with zeros
900   if (rows_32 < rows) {
901     // Pack depth in sets of 8
902     for (Index k = 0; k < depth_8; k += 8) {
903       // Load vectors
904       __m256i L_A = _mm256_setzero_si256();
905       __m256i L_B = _mm256_setzero_si256();
906       __m256i L_C = _mm256_setzero_si256();
907       __m256i L_D = _mm256_setzero_si256();
908       __m256i L_E = _mm256_setzero_si256();
909       __m256i L_F = _mm256_setzero_si256();
910       __m256i L_G = _mm256_setzero_si256();
911       __m256i L_H = _mm256_setzero_si256();
912       for (Index m = 0; m < rows - rows_32; m++) {
913         QInt8* ptr = (QInt8*)&L_A;
914         ptr[m] = lhs(rows_32 + m, k);
915         ptr = (QInt8*)&L_B;
916         ptr[m] = lhs(rows_32 + m, k + 1);
917         ptr = (QInt8*)&L_C;
918         ptr[m] = lhs(rows_32 + m, k + 2);
919         ptr = (QInt8*)&L_D;
920         ptr[m] = lhs(rows_32 + m, k + 3);
921         ptr = (QInt8*)&L_E;
922         ptr[m] = lhs(rows_32 + m, k + 4);
923         ptr = (QInt8*)&L_F;
924         ptr[m] = lhs(rows_32 + m, k + 5);
925         ptr = (QInt8*)&L_G;
926         ptr[m] = lhs(rows_32 + m, k + 6);
927         ptr = (QInt8*)&L_H;
928         ptr[m] = lhs(rows_32 + m, k + 7);
929       }
930 
931       // Interleave 8-bit elements
932       __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
933       __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
934       __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
935       __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
936 
937       // Interleave 16-bit elements
938       __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
939       __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
940 
941       // Use permute before we store to cross 128-bit lanes
942       __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
943       _mm256_store_si256(blockA_256++, L_AD0);
944 
945       // Complete packing for 32 x 8 block
946       __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
947       __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
948       __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
949       __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
950       _mm256_store_si256(blockA_256++, L_AD8);
951       _mm256_store_si256(blockA_256++, L_AD16);
952       __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
953       _mm256_store_si256(blockA_256++, L_AD24);
954       __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
955       __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
956       __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
957       __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
958       __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
959       __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
960       __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
961       _mm256_store_si256(blockA_256++, L_EH0);
962       __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
963       __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
964       __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
965       __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
966       _mm256_store_si256(blockA_256++, L_EH8);
967       _mm256_store_si256(blockA_256++, L_EH16);
968       __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
969       _mm256_store_si256(blockA_256++, L_EH24);
970     }
971 
972     // Finish the k dimension, padding with zeros
973     if (depth_8 < depth) {
974       __m256i L_A, L_B, L_C, L_D, L_E, L_F, L_G, L_H;
975       QInt8* ptr;
976       switch (depth - depth_8) {
977         case 1:
978           L_A = _mm256_setzero_si256();
979           L_B = _mm256_setzero_si256();
980           L_C = _mm256_setzero_si256();
981           L_D = _mm256_setzero_si256();
982           L_E = _mm256_setzero_si256();
983           L_F = _mm256_setzero_si256();
984           L_G = _mm256_setzero_si256();
985           L_H = _mm256_setzero_si256();
986           for (Index m = 0; m < rows - rows_32; m++) {
987             QInt8* ptr = (QInt8*)&L_A;
988             ptr[m] = lhs(rows_32 + m, depth_8);
989           }
990           break;
991         case 2:
992           L_A = _mm256_setzero_si256();
993           L_B = _mm256_setzero_si256();
994           L_C = _mm256_setzero_si256();
995           L_D = _mm256_setzero_si256();
996           L_E = _mm256_setzero_si256();
997           L_F = _mm256_setzero_si256();
998           L_G = _mm256_setzero_si256();
999           L_H = _mm256_setzero_si256();
1000           for (Index m = 0; m < rows - rows_32; m++) {
1001             ptr = (QInt8*)&L_A;
1002             ptr[m] = lhs(rows_32 + m, depth_8);
1003             ptr = (QInt8*)&L_B;
1004             ptr[m] = lhs(rows_32 + m, depth_8 + 1);
1005           }
1006           break;
1007         case 3:
1008           L_A = _mm256_setzero_si256();
1009           L_B = _mm256_setzero_si256();
1010           L_C = _mm256_setzero_si256();
1011           L_D = _mm256_setzero_si256();
1012           L_E = _mm256_setzero_si256();
1013           L_F = _mm256_setzero_si256();
1014           L_G = _mm256_setzero_si256();
1015           L_H = _mm256_setzero_si256();
1016           for (Index m = 0; m < rows - rows_32; m++) {
1017             ptr = (QInt8*)&L_A;
1018             ptr[m] = lhs(rows_32 + m, depth_8);
1019             ptr = (QInt8*)&L_B;
1020             ptr[m] = lhs(rows_32 + m, depth_8 + 1);
1021             ptr = (QInt8*)&L_C;
1022             ptr[m] = lhs(rows_32 + m, depth_8 + 2);
1023           }
1024           break;
1025         case 4:
1026           L_A = _mm256_setzero_si256();
1027           L_B = _mm256_setzero_si256();
1028           L_C = _mm256_setzero_si256();
1029           L_D = _mm256_setzero_si256();
1030           L_E = _mm256_setzero_si256();
1031           L_F = _mm256_setzero_si256();
1032           L_G = _mm256_setzero_si256();
1033           L_H = _mm256_setzero_si256();
1034           for (Index m = 0; m < rows - rows_32; m++) {
1035             ptr = (QInt8*)&L_A;
1036             ptr[m] = lhs(rows_32 + m, depth_8);
1037             ptr = (QInt8*)&L_B;
1038             ptr[m] = lhs(rows_32 + m, depth_8 + 1);
1039             ptr = (QInt8*)&L_C;
1040             ptr[m] = lhs(rows_32 + m, depth_8 + 2);
1041             ptr = (QInt8*)&L_D;
1042             ptr[m] = lhs(rows_32 + m, depth_8 + 3);
1043           }
1044           break;
1045         case 5:
1046           L_A = _mm256_setzero_si256();
1047           L_B = _mm256_setzero_si256();
1048           L_C = _mm256_setzero_si256();
1049           L_D = _mm256_setzero_si256();
1050           L_E = _mm256_setzero_si256();
1051           L_F = _mm256_setzero_si256();
1052           L_G = _mm256_setzero_si256();
1053           L_H = _mm256_setzero_si256();
1054           for (Index m = 0; m < rows - rows_32; m++) {
1055             ptr = (QInt8*)&L_A;
1056             ptr[m] = lhs(rows_32 + m, depth_8);
1057             ptr = (QInt8*)&L_B;
1058             ptr[m] = lhs(rows_32 + m, depth_8 + 1);
1059             ptr = (QInt8*)&L_C;
1060             ptr[m] = lhs(rows_32 + m, depth_8 + 2);
1061             ptr = (QInt8*)&L_D;
1062             ptr[m] = lhs(rows_32 + m, depth_8 + 3);
1063             ptr = (QInt8*)&L_E;
1064             ptr[m] = lhs(rows_32 + m, depth_8 + 4);
1065           }
1066           break;
1067         case 6:
1068           L_A = _mm256_setzero_si256();
1069           L_B = _mm256_setzero_si256();
1070           L_C = _mm256_setzero_si256();
1071           L_D = _mm256_setzero_si256();
1072           L_E = _mm256_setzero_si256();
1073           L_F = _mm256_setzero_si256();
1074           L_G = _mm256_setzero_si256();
1075           L_H = _mm256_setzero_si256();
1076           for (Index m = 0; m < rows - rows_32; m++) {
1077             ptr = (QInt8*)&L_A;
1078             ptr[m] = lhs(rows_32 + m, depth_8);
1079             ptr = (QInt8*)&L_B;
1080             ptr[m] = lhs(rows_32 + m, depth_8 + 1);
1081             ptr = (QInt8*)&L_C;
1082             ptr[m] = lhs(rows_32 + m, depth_8 + 2);
1083             ptr = (QInt8*)&L_D;
1084             ptr[m] = lhs(rows_32 + m, depth_8 + 3);
1085             ptr = (QInt8*)&L_E;
1086             ptr[m] = lhs(rows_32 + m, depth_8 + 4);
1087             ptr = (QInt8*)&L_F;
1088             ptr[m] = lhs(rows_32 + m, depth_8 + 5);
1089           }
1090           break;
1091         case 7:
1092           L_A = _mm256_setzero_si256();
1093           L_B = _mm256_setzero_si256();
1094           L_C = _mm256_setzero_si256();
1095           L_D = _mm256_setzero_si256();
1096           L_E = _mm256_setzero_si256();
1097           L_F = _mm256_setzero_si256();
1098           L_G = _mm256_setzero_si256();
1099           L_H = _mm256_setzero_si256();
1100           for (Index m = 0; m < rows - rows_32; m++) {
1101             ptr = (QInt8*)&L_A;
1102             ptr[m] = lhs(rows_32 + m, depth_8);
1103             ptr = (QInt8*)&L_B;
1104             ptr[m] = lhs(rows_32 + m, depth_8 + 1);
1105             ptr = (QInt8*)&L_C;
1106             ptr[m] = lhs(rows_32 + m, depth_8 + 2);
1107             ptr = (QInt8*)&L_D;
1108             ptr[m] = lhs(rows_32 + m, depth_8 + 3);
1109             ptr = (QInt8*)&L_E;
1110             ptr[m] = lhs(rows_32 + m, depth_8 + 4);
1111             ptr = (QInt8*)&L_F;
1112             ptr[m] = lhs(rows_32 + m, depth_8 + 5);
1113             ptr = (QInt8*)&L_G;
1114             ptr[m] = lhs(rows_32 + m, depth_8 + 6);
1115           }
1116           break;
1117       }
1118 
1119       // Interleave 8-bit elements
1120       __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
1121       __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
1122       __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
1123       __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
1124 
1125       // Interleave 16-bit elements
1126       __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
1127       __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
1128 
1129       // Use permute before we store to cross 128-bit lanes
1130       __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
1131       _mm256_store_si256(blockA_256++, L_AD0);
1132 
1133       // Complete packing
1134       __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
1135       __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
1136       __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
1137       __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
1138       _mm256_store_si256(blockA_256++, L_AD8);
1139       _mm256_store_si256(blockA_256++, L_AD16);
1140       __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
1141       _mm256_store_si256(blockA_256++, L_AD24);
1142       __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
1143       __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
1144       __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
1145       __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
1146       __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
1147       __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
1148       __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
1149       _mm256_store_si256(blockA_256++, L_EH0);
1150       __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
1151       __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
1152       __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
1153       __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
1154       _mm256_store_si256(blockA_256++, L_EH8);
1155       _mm256_store_si256(blockA_256++, L_EH16);
1156       __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
1157       _mm256_store_si256(blockA_256++, L_EH24);
1158     }
1159   }
1160 }
1161 
1162 template <typename Index, typename DataMapper, int nr, bool Conjugate,
1163           bool PanelMode>
1164 EIGEN_DONT_INLINE void gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr,
1165                                          ColMajor, Conjugate, PanelMode>::
1166 operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
1167            Index stride, Index offset) {
1168   eigen_assert(stride == 0);
1169   eigen_assert(offset == 0);
1170 
1171   typedef typename packet_traits<QUInt8>::type Packet;
1172 
1173   // Get vector pointer
1174   __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);
1175 
1176   // Get even multiples of the dimensions
1177   Index cols_32 = (cols / 32) * 32;
1178   Index depth_32 = (depth / 32) * 32;
1179 
1180   // Perform a step of the packing for 4 columns
1181   __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24;
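  // A sketch of the layout produced by PACK_STEP below: for cols n..n+3 and
  // depths k..k+31, the vector stored at offset 0 holds rhs(k..k+7, n..n+3),
  // and the vectors at offsets +8, +16 and +24 hold the same four cols for
  // depths k+8..k+15, k+16..k+23 and k+24..k+31, i.e. 8 adjacent depth values
  // per col for the 8-bit kernel.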
1182 #define PACK_STEP                                            \
1183   R_AB_L = _mm256_unpacklo_epi64(R_A, R_B);                  \
1184   R_CD_L = _mm256_unpacklo_epi64(R_C, R_D);                  \
1185   R_AB_H = _mm256_unpackhi_epi64(R_A, R_B);                  \
1186   R_CD_H = _mm256_unpackhi_epi64(R_C, R_D);                  \
1187   R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20);  \
1188   R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \
1189   R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20);  \
1190   R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \
1191   _mm256_store_si256(blockB_256, R_AD_0);                    \
1192   _mm256_store_si256(blockB_256 + 8, R_AD_8);                \
1193   _mm256_store_si256(blockB_256 + 16, R_AD_16);              \
1194   _mm256_store_si256(blockB_256 + 24, R_AD_24);              \
1195   blockB_256++;
1196 
1197   // Pack cols in sets of 32
1198   for (Index n = 0; n < cols_32; n += 32) {
1199     // Pack depth in sets of 32
1200     for (Index k = 0; k < depth_32; k += 32) {
1201       __m256i R_A = rhs.template loadPacket<Packet>(k, n);
1202       __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
1203       __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
1204       __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
1205       PACK_STEP;
1206 
1207       R_A = rhs.template loadPacket<Packet>(k, n + 4);
1208       R_B = rhs.template loadPacket<Packet>(k, n + 5);
1209       R_C = rhs.template loadPacket<Packet>(k, n + 6);
1210       R_D = rhs.template loadPacket<Packet>(k, n + 7);
1211       PACK_STEP;
1212 
1213       R_A = rhs.template loadPacket<Packet>(k, n + 8);
1214       R_B = rhs.template loadPacket<Packet>(k, n + 9);
1215       R_C = rhs.template loadPacket<Packet>(k, n + 10);
1216       R_D = rhs.template loadPacket<Packet>(k, n + 11);
1217       PACK_STEP;
1218 
1219       R_A = rhs.template loadPacket<Packet>(k, n + 12);
1220       R_B = rhs.template loadPacket<Packet>(k, n + 13);
1221       R_C = rhs.template loadPacket<Packet>(k, n + 14);
1222       R_D = rhs.template loadPacket<Packet>(k, n + 15);
1223       PACK_STEP;
1224 
1225       R_A = rhs.template loadPacket<Packet>(k, n + 16);
1226       R_B = rhs.template loadPacket<Packet>(k, n + 17);
1227       R_C = rhs.template loadPacket<Packet>(k, n + 18);
1228       R_D = rhs.template loadPacket<Packet>(k, n + 19);
1229       PACK_STEP;
1230 
1231       R_A = rhs.template loadPacket<Packet>(k, n + 20);
1232       R_B = rhs.template loadPacket<Packet>(k, n + 21);
1233       R_C = rhs.template loadPacket<Packet>(k, n + 22);
1234       R_D = rhs.template loadPacket<Packet>(k, n + 23);
1235       PACK_STEP;
1236 
1237       R_A = rhs.template loadPacket<Packet>(k, n + 24);
1238       R_B = rhs.template loadPacket<Packet>(k, n + 25);
1239       R_C = rhs.template loadPacket<Packet>(k, n + 26);
1240       R_D = rhs.template loadPacket<Packet>(k, n + 27);
1241       PACK_STEP;
1242 
1243       R_A = rhs.template loadPacket<Packet>(k, n + 28);
1244       R_B = rhs.template loadPacket<Packet>(k, n + 29);
1245       R_C = rhs.template loadPacket<Packet>(k, n + 30);
1246       R_D = rhs.template loadPacket<Packet>(k, n + 31);
1247       PACK_STEP;
1248 
1249       blockB_256 += 24;
1250     }
1251 
1252     if (depth_32 < depth) {
1253       QUInt8* ptr;
1254       __m256i R_A = _mm256_setzero_si256();
1255       __m256i R_B = _mm256_setzero_si256();
1256       __m256i R_C = _mm256_setzero_si256();
1257       __m256i R_D = _mm256_setzero_si256();
1258       for (Index k = depth_32; k < depth; k++) {
1259         ptr = (QUInt8*)&R_A;
1260         ptr[k - depth_32] = rhs(k, n);
1261         ptr = (QUInt8*)&R_B;
1262         ptr[k - depth_32] = rhs(k, n + 1);
1263         ptr = (QUInt8*)&R_C;
1264         ptr[k - depth_32] = rhs(k, n + 2);
1265         ptr = (QUInt8*)&R_D;
1266         ptr[k - depth_32] = rhs(k, n + 3);
1267       }
1268       PACK_STEP;
1269 
1270       R_A = _mm256_setzero_si256();
1271       R_B = _mm256_setzero_si256();
1272       R_C = _mm256_setzero_si256();
1273       R_D = _mm256_setzero_si256();
1274       for (Index k = depth_32; k < depth; k++) {
1275         ptr = (QUInt8*)&R_A;
1276         ptr[k - depth_32] = rhs(k, n + 4);
1277         ptr = (QUInt8*)&R_B;
1278         ptr[k - depth_32] = rhs(k, n + 5);
1279         ptr = (QUInt8*)&R_C;
1280         ptr[k - depth_32] = rhs(k, n + 6);
1281         ptr = (QUInt8*)&R_D;
1282         ptr[k - depth_32] = rhs(k, n + 7);
1283       }
1284       PACK_STEP;
1285 
1286       R_A = _mm256_setzero_si256();
1287       R_B = _mm256_setzero_si256();
1288       R_C = _mm256_setzero_si256();
1289       R_D = _mm256_setzero_si256();
1290       for (Index k = depth_32; k < depth; k++) {
1291         ptr = (QUInt8*)&R_A;
1292         ptr[k - depth_32] = rhs(k, n + 8);
1293         ptr = (QUInt8*)&R_B;
1294         ptr[k - depth_32] = rhs(k, n + 9);
1295         ptr = (QUInt8*)&R_C;
1296         ptr[k - depth_32] = rhs(k, n + 10);
1297         ptr = (QUInt8*)&R_D;
1298         ptr[k - depth_32] = rhs(k, n + 11);
1299       }
1300       PACK_STEP;
1301 
1302       R_A = _mm256_setzero_si256();
1303       R_B = _mm256_setzero_si256();
1304       R_C = _mm256_setzero_si256();
1305       R_D = _mm256_setzero_si256();
1306       for (Index k = depth_32; k < depth; k++) {
1307         ptr = (QUInt8*)&R_A;
1308         ptr[k - depth_32] = rhs(k, n + 12);
1309         ptr = (QUInt8*)&R_B;
1310         ptr[k - depth_32] = rhs(k, n + 13);
1311         ptr = (QUInt8*)&R_C;
1312         ptr[k - depth_32] = rhs(k, n + 14);
1313         ptr = (QUInt8*)&R_D;
1314         ptr[k - depth_32] = rhs(k, n + 15);
1315       }
1316       PACK_STEP;
1317 
1318       R_A = _mm256_setzero_si256();
1319       R_B = _mm256_setzero_si256();
1320       R_C = _mm256_setzero_si256();
1321       R_D = _mm256_setzero_si256();
1322       for (Index k = depth_32; k < depth; k++) {
1323         ptr = (QUInt8*)&R_A;
1324         ptr[k - depth_32] = rhs(k, n + 16);
1325         ptr = (QUInt8*)&R_B;
1326         ptr[k - depth_32] = rhs(k, n + 17);
1327         ptr = (QUInt8*)&R_C;
1328         ptr[k - depth_32] = rhs(k, n + 18);
1329         ptr = (QUInt8*)&R_D;
1330         ptr[k - depth_32] = rhs(k, n + 19);
1331       }
1332       PACK_STEP;
1333 
1334       R_A = _mm256_setzero_si256();
1335       R_B = _mm256_setzero_si256();
1336       R_C = _mm256_setzero_si256();
1337       R_D = _mm256_setzero_si256();
1338       for (Index k = depth_32; k < depth; k++) {
1339         ptr = (QUInt8*)&R_A;
1340         ptr[k - depth_32] = rhs(k, n + 20);
1341         ptr = (QUInt8*)&R_B;
1342         ptr[k - depth_32] = rhs(k, n + 21);
1343         ptr = (QUInt8*)&R_C;
1344         ptr[k - depth_32] = rhs(k, n + 22);
1345         ptr = (QUInt8*)&R_D;
1346         ptr[k - depth_32] = rhs(k, n + 23);
1347       }
1348       PACK_STEP;
1349 
1350       R_A = _mm256_setzero_si256();
1351       R_B = _mm256_setzero_si256();
1352       R_C = _mm256_setzero_si256();
1353       R_D = _mm256_setzero_si256();
1354       for (Index k = depth_32; k < depth; k++) {
1355         ptr = (QUInt8*)&R_A;
1356         ptr[k - depth_32] = rhs(k, n + 24);
1357         ptr = (QUInt8*)&R_B;
1358         ptr[k - depth_32] = rhs(k, n + 25);
1359         ptr = (QUInt8*)&R_C;
1360         ptr[k - depth_32] = rhs(k, n + 26);
1361         ptr = (QUInt8*)&R_D;
1362         ptr[k - depth_32] = rhs(k, n + 27);
1363       }
1364       PACK_STEP;
1365 
1366       R_A = _mm256_setzero_si256();
1367       R_B = _mm256_setzero_si256();
1368       R_C = _mm256_setzero_si256();
1369       R_D = _mm256_setzero_si256();
1370       for (Index k = depth_32; k < depth; k++) {
1371         ptr = (QUInt8*)&R_A;
1372         ptr[k - depth_32] = rhs(k, n + 28);
1373         ptr = (QUInt8*)&R_B;
1374         ptr[k - depth_32] = rhs(k, n + 29);
1375         ptr = (QUInt8*)&R_C;
1376         ptr[k - depth_32] = rhs(k, n + 30);
1377         ptr = (QUInt8*)&R_D;
1378         ptr[k - depth_32] = rhs(k, n + 31);
1379       }
1380       PACK_STEP;
1381       blockB_256 += 24;
1382     }
1383   }
1384 
1385   // Finish packing cols
1386   if (cols_32 < cols) {
1387     // Pack depth in sets of 32
1388     for (Index k = 0; k < depth_32; k += 32) {
1389       __m256i R_A, R_B, R_C, R_D;
1390       Index n;
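      // When fewer than four columns remain, the missing columns are replaced
      // by zero vectors so that a full PACK_STEP can still be issued.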
1391       for (n = cols_32; n < cols; n += 4) {
1392         switch (cols - n) {
1393           case 1:
1394             R_A = rhs.template loadPacket<Packet>(k, n);
1395             R_B = _mm256_setzero_si256();
1396             R_C = _mm256_setzero_si256();
1397             R_D = _mm256_setzero_si256();
1398             PACK_STEP;
1399             break;
1400           case 2:
1401             R_A = rhs.template loadPacket<Packet>(k, n);
1402             R_B = rhs.template loadPacket<Packet>(k, n + 1);
1403             R_C = _mm256_setzero_si256();
1404             R_D = _mm256_setzero_si256();
1405             PACK_STEP;
1406             break;
1407           case 3:
1408             R_A = rhs.template loadPacket<Packet>(k, n);
1409             R_B = rhs.template loadPacket<Packet>(k, n + 1);
1410             R_C = rhs.template loadPacket<Packet>(k, n + 2);
1411             R_D = _mm256_setzero_si256();
1412             PACK_STEP;
1413             break;
1414           default:
1415             R_A = rhs.template loadPacket<Packet>(k, n);
1416             R_B = rhs.template loadPacket<Packet>(k, n + 1);
1417             R_C = rhs.template loadPacket<Packet>(k, n + 2);
1418             R_D = rhs.template loadPacket<Packet>(k, n + 3);
1419             PACK_STEP;
1420             break;
1421         }
1422       }
1423 
1424       // Advance the block pointer: each PACK_STEP moved it by one __m256i,
1425       // so skip the slots reserved for the padding columns of this panel.
1426       blockB_256 += 32 - (n - cols_32) / 4;
1427     }
1428 
1429     if (depth_32 < depth) {
1430       for (Index n = cols_32; n < cols; n += 4) {
1431         QUInt8* ptr;
1432         __m256i R_A = _mm256_setzero_si256();
1433         __m256i R_B = _mm256_setzero_si256();
1434         __m256i R_C = _mm256_setzero_si256();
1435         __m256i R_D = _mm256_setzero_si256();
1436         switch (cols - n) {
1437           case 1:
1438             for (Index k = depth_32; k < depth; k++) {
1439               ptr = (QUInt8*)&R_A;
1440               ptr[k - depth_32] = rhs(k, n);
1441             }
1442             PACK_STEP;
1443             break;
1444           case 2:
1445             for (Index k = depth_32; k < depth; k++) {
1446               ptr = (QUInt8*)&R_A;
1447               ptr[k - depth_32] = rhs(k, n);
1448               ptr = (QUInt8*)&R_B;
1449               ptr[k - depth_32] = rhs(k, n + 1);
1450             }
1451             PACK_STEP;
1452             break;
1453           case 3:
1454             for (Index k = depth_32; k < depth; k++) {
1455               ptr = (QUInt8*)&R_A;
1456               ptr[k - depth_32] = rhs(k, n);
1457               ptr = (QUInt8*)&R_B;
1458               ptr[k - depth_32] = rhs(k, n + 1);
1459               ptr = (QUInt8*)&R_C;
1460               ptr[k - depth_32] = rhs(k, n + 2);
1461             }
1462             PACK_STEP;
1463             break;
1464           default:
1465             for (Index k = depth_32; k < depth; k++) {
1466               ptr = (QUInt8*)&R_A;
1467               ptr[k - depth_32] = rhs(k, n);
1468               ptr = (QUInt8*)&R_B;
1469               ptr[k - depth_32] = rhs(k, n + 1);
1470               ptr = (QUInt8*)&R_C;
1471               ptr[k - depth_32] = rhs(k, n + 2);
1472               ptr = (QUInt8*)&R_D;
1473               ptr[k - depth_32] = rhs(k, n + 3);
1474             }
1475             PACK_STEP;
1476             break;
1477         }
1478       }
1479     }
1480   }
1481 #undef PACK_STEP
1482 }
1483 
1484 template <typename Index, typename DataMapper, int mr, int nr,
1485           bool ConjugateLhs, bool ConjugateRhs>
1486 EIGEN_DONT_INLINE void gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr,
1487                                        ConjugateLhs, ConjugateRhs>::
1488 operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
1489            Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
1490            Index strideB, Index offsetA, Index offsetB) {
1491   EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
1492   EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
1493   eigen_assert(alpha.value == 1);
1494   eigen_assert(strideA == -1);
1495   eigen_assert(strideB == -1);
1496   eigen_assert(offsetA == 0);
1497   eigen_assert(offsetB == 0);
1498   eigen_assert(rows > 0);
1499   eigen_assert(cols > 0);
1500   eigen_assert(depth > 0);
1501   eigen_assert(blockA);
1502   eigen_assert(blockB);
1503 
1504   Index rows_32 = ((rows + 31) / 32) * 32;
1505   Index cols_32 = ((cols + 31) / 32) * 32;
1506   Index depth_32 = ((depth + 31) / 32) * 32;
1507 
1508   // Create result block
1509   ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
1510   memset(blockO, 0, 32 * 32 * sizeof(QInt32));
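  // blockO holds the 32x32 accumulator tile in column-major order: four
  // __m256i (rows 0-7, 8-15, 16-23 and 24-31) per result column.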
1511 
1512   // Get vectorized pointers
1513   __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
1514   const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
1515   const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);
1516 
1517   // Loop over blocks of 32 columns
1518   for (Index n = 0; n < cols_32; n += 32) {
1519     // Reset index into blockA
1520     Index indexL = 0;
1521     // Loop over blocks of 32 rows
1522     for (Index m = 0; m < rows_32; m += 32) {
1523       // Reset index into blockB
1524       Index indexR = n / 32 * depth_32;
1525       // Loop over blocks of 8 on depth
1526       for (Index k = 0; k < depth_32; k += 8) {
1527         // Load inputs
1528         __m256i L_AD0 = blockA_256[indexL++];
1529         __m256i L_AD8 = blockA_256[indexL++];
1530         __m256i L_AD16 = blockA_256[indexL++];
1531         __m256i L_AD24 = blockA_256[indexL++];
1532         __m256i L_EH0 = blockA_256[indexL++];
1533         __m256i L_EH8 = blockA_256[indexL++];
1534         __m256i L_EH16 = blockA_256[indexL++];
1535         __m256i L_EH24 = blockA_256[indexL++];
1536         __m256i R_AH0 = blockB_256[indexR++];
1537         __m256i R_AH4 = blockB_256[indexR++];
1538         __m256i R_AH8 = blockB_256[indexR++];
1539         __m256i R_AH12 = blockB_256[indexR++];
1540         __m256i R_AH16 = blockB_256[indexR++];
1541         __m256i R_AH20 = blockB_256[indexR++];
1542         __m256i R_AH24 = blockB_256[indexR++];
1543         __m256i R_AH28 = blockB_256[indexR++];
1544 
1545         // Used with madd to sum adjacent 16-bit values into 32-bit results
1546         const __m256i ONE = _mm256_set1_epi32(0x00010001);
1547 
1548         // Declare variables used in COMPUTE_STEP
1549         __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;
1550 
1551 #define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                             \
1552   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0);                             \
1553   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
1554   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0);                             \
1555   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
1556   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
1557   _mm256_store_si256(                                                          \
1558       blockO_256 + 4 * OFFSET,                                                 \
1559       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32));     \
1560                                                                                \
1561   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8);                             \
1562   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
1563   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8);                             \
1564   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
1565   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
1566   _mm256_store_si256(                                                          \
1567       blockO_256 + 4 * OFFSET + 1,                                             \
1568       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \
1569                                                                                \
1570   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16);                            \
1571   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
1572   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16);                            \
1573   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
1574   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
1575   _mm256_store_si256(                                                          \
1576       blockO_256 + 4 * OFFSET + 2,                                             \
1577       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \
1578                                                                                \
1579   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24);                            \
1580   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
1581   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24);                            \
1582   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
1583   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
1584   _mm256_store_si256(                                                          \
1585       blockO_256 + 4 * OFFSET + 3,                                             \
1586       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32));
1587 
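        // COMPUTE_STEP accumulates, for the result column selected by OFFSET,
        // the contribution of 8 depth entries to all 32 rows of the tile:
        // maddubs multiplies unsigned RHS bytes with signed LHS bytes and sums
        // adjacent pairs into 16 bits, madd against ONE widens and sums the
        // remaining pairs into 32 bits, and the four stores update rows 0-7,
        // 8-15, 16-23 and 24-31 of that column in blockO.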
1588         // Permute and shuffle to broadcast four depth bytes of a single RHS
1589         // column across each vector, then accumulate that column's products
1590         __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
1591         __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1592         __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1593         COMPUTE_STEP(R_AD0, R_EH0, 0);
1594         __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1595         __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1596         COMPUTE_STEP(R_AD1, R_EH1, 1);
1597         R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
1598         __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1599         __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1600         COMPUTE_STEP(R_AD2, R_EH2, 2);
1601         __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1602         __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1603         COMPUTE_STEP(R_AD3, R_EH3, 3);
1604 
1605         R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
1606         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1607         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1608         COMPUTE_STEP(R_AD0, R_EH0, 4);
1609         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1610         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1611         COMPUTE_STEP(R_AD1, R_EH1, 5);
1612         R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
1613         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1614         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1615         COMPUTE_STEP(R_AD2, R_EH2, 6);
1616         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1617         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1618         COMPUTE_STEP(R_AD3, R_EH3, 7);
1619 
1620         R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
1621         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1622         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1623         COMPUTE_STEP(R_AD0, R_EH0, 8);
1624         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1625         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1626         COMPUTE_STEP(R_AD1, R_EH1, 9);
1627         R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
1628         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1629         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1630         COMPUTE_STEP(R_AD2, R_EH2, 10);
1631         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1632         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1633         COMPUTE_STEP(R_AD3, R_EH3, 11);
1634 
1635         R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
1636         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1637         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1638         COMPUTE_STEP(R_AD0, R_EH0, 12);
1639         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1640         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1641         COMPUTE_STEP(R_AD1, R_EH1, 13);
1642         R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
1643         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1644         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1645         COMPUTE_STEP(R_AD2, R_EH2, 14);
1646         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1647         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1648         COMPUTE_STEP(R_AD3, R_EH3, 15);
1649 
1650         R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00);
1651         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1652         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1653         COMPUTE_STEP(R_AD0, R_EH0, 16);
1654         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1655         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1656         COMPUTE_STEP(R_AD1, R_EH1, 17);
1657         R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11);
1658         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1659         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1660         COMPUTE_STEP(R_AD2, R_EH2, 18);
1661         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1662         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1663         COMPUTE_STEP(R_AD3, R_EH3, 19);
1664 
1665         R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00);
1666         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1667         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1668         COMPUTE_STEP(R_AD0, R_EH0, 20);
1669         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1670         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1671         COMPUTE_STEP(R_AD1, R_EH1, 21);
1672         R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11);
1673         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1674         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1675         COMPUTE_STEP(R_AD2, R_EH2, 22);
1676         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1677         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1678         COMPUTE_STEP(R_AD3, R_EH3, 23);
1679 
1680         R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00);
1681         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1682         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1683         COMPUTE_STEP(R_AD0, R_EH0, 24);
1684         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1685         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1686         COMPUTE_STEP(R_AD1, R_EH1, 25);
1687         R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11);
1688         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1689         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1690         COMPUTE_STEP(R_AD2, R_EH2, 26);
1691         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1692         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1693         COMPUTE_STEP(R_AD3, R_EH3, 27);
1694 
1695         R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00);
1696         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1697         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1698         COMPUTE_STEP(R_AD0, R_EH0, 28);
1699         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1700         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1701         COMPUTE_STEP(R_AD1, R_EH1, 29);
1702         R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11);
1703         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
1704         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
1705         COMPUTE_STEP(R_AD2, R_EH2, 30);
1706         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
1707         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
1708         COMPUTE_STEP(R_AD3, R_EH3, 31);
1709 
1710 #undef COMPUTE_STEP
1711       }
1712 
1713       // Transfer the results to the result matrix.
1714       if (m + 32 <= rows && n + 32 <= cols) {
1715         Index i = 0;
1716         for (Index j = n; j < n + 32; j++) {
1717           LinearMapper r0 = res.getLinearMapper(m, j);
1718           LinearMapper r1 = res.getLinearMapper(m + 8, j);
1719           LinearMapper r2 = res.getLinearMapper(m + 16, j);
1720           LinearMapper r3 = res.getLinearMapper(m + 24, j);
1721           typedef typename packet_traits<QInt32>::type Packet;
1722           r0.template storePacket<Packet>(
1723               0, _mm256_add_epi32(blockO_256[i++],
1724                                   r0.template loadPacket<Packet>(0)));
1725           r1.template storePacket<Packet>(
1726               0, _mm256_add_epi32(blockO_256[i++],
1727                                   r1.template loadPacket<Packet>(0)));
1728           r2.template storePacket<Packet>(
1729               0, _mm256_add_epi32(blockO_256[i++],
1730                                   r2.template loadPacket<Packet>(0)));
1731           r3.template storePacket<Packet>(
1732               0, _mm256_add_epi32(blockO_256[i++],
1733                                   r3.template loadPacket<Packet>(0)));
1734         }
1735       } else {
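        // Partial tile: copy only the valid (rows - m) x (cols - n) entries of
        // blockO into the result, element by element.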
1736         for (Index j = n; j < cols; j++) {
1737           for (Index i = m; i < rows; i++) {
1738             res(i, j) = blockO[(j - n) * 32 + (i - m)];
1739           }
1740         }
1741       }
1742 
1743       // Zero the result block so it can be reused
1744       memset(blockO, 0, 32 * 32 * sizeof(QInt32));
1745     }
1746   }
1747 }
1748 
1749 // Below are the fully optimized versions, which are correct only for sizes
1750 // that are multiples of 32.  Keeping these implementations separate from the
1751 // any-size versions above yields roughly a 10% performance benefit.
1752 
1753 // Arrange a block of the left input matrix in contiguous memory.
1754 //
1755 // Given column major input (A0 beside A1 in memory):
1756 // A0 B0 C0 D0 E0 F0 G0 H0 ...
1757 // A1 B1 C1 D1 E1 F1 G1 H1 ...
1758 // A2 B2 C2 D2 E2 F2 G2 H2 ...
1759 // A3 B3 C3 D3 E3 F3 G3 H3 ...
1760 // A4 B4 C4 D4 E4 F4 G4 H4 ...
1761 // A5 B5 C5 D5 E5 F5 G5 H5 ...
1762 // A6 B6 C6 D6 E6 F6 G6 H6 ...
1763 // A7 B7 C7 D7 E7 F7 G7 H7 ...
1764 // A8 ...
1765 // ...
1766 //
1767 // Packing yields output (A0 beside B0 in memory):
1768 // A0 B0 C0 D0
1769 // A1 B1 C1 D1
1770 // A2 B2 C2 D2
1771 // A3 B3 C3 D3
1772 // A4 B4 C4 D4
1773 // A5 B5 C5 D5
1774 // A6 B6 C6 D6
1775 // A7 B7 C7 D7
1776 // ...
1777 // A31 B31 C31 D31
1778 // E0 F0 G0 H0
1779 // E1 F1 G1 H1
1780 // E2 F2 G2 H2
1781 // E3 F3 G3 H3
1782 // E4 F4 G4 H4
1783 // E5 F5 G5 H5
1784 // E6 F6 G6 H6
1785 // E7 F7 G7 H7
1786 // ...
1787 //
1788 // Four elements of the same row are arranged contiguously because maddubs and
1789 // madd both perform an adjacent addition in the kernel.
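// For every 32-row x 8-depth step the packed block therefore holds eight
// 32-byte vectors in the order AD0, AD8, AD16, AD24, EH0, EH8, EH16, EH24
// (rows 0-7, 8-15, 16-23 and 24-31 for the first four depth entries, then the
// same rows for the next four), matching the order in which the kernel loads
// them.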
1790 template <typename Index, typename DataMapper, int Pack1, int Pack2,
1791           bool Conjugate, bool PanelMode>
1792 struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, QInt8, ColMajor,
1793                      Conjugate, PanelMode> {
1794   EIGEN_DONT_INLINE void operator()(QInt8* blockA, const DataMapper& lhs,
1795                                     Index depth, Index rows, Index stride = 0,
1796                                     Index offset = 0);
1797 };
1798 
1799 template <typename Index, typename DataMapper, int Pack1, int Pack2,
1800           bool Conjugate, bool PanelMode>
1801 EIGEN_DONT_INLINE void gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2,
1802                                      QInt8, ColMajor, Conjugate, PanelMode>::
1803 operator()(QInt8* blockA, const DataMapper& lhs, Index depth, Index rows,
1804            Index stride, Index offset) {
1805   eigen_assert(stride == 0);
1806   eigen_assert(offset == 0);
1807 
1808   typedef typename packet_traits<QInt8>::type Packet;
1809 
1810   // Use the any-size routine when rows or depth is not a multiple of 32
1811   if (rows % 32 != 0 || depth % 32 != 0) {
1812     gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
1813                       Conjugate, PanelMode> lhs_pack;
1814     return lhs_pack(blockA, lhs, depth, rows, stride, offset);
1815   }
1816 
1817   // Get vector pointer
1818   __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);
1819 
1820   // Pack rows in sets of 32
1821   for (Index m = 0; m < rows; m += 32) {
1822     // Pack depth in sets of 8
1823     for (Index k = 0; k < depth; k += 8) {
1824       // Load vectors
1825       __m256i L_A = lhs.template loadPacket<Packet>(m, k);
1826       __m256i L_B = lhs.template loadPacket<Packet>(m, k + 1);
1827 
1828       // Interleave 8-bit elements
1829       __m256i L_AB0_AB16 = _mm256_unpacklo_epi8(L_A, L_B);
1830       __m256i L_AB8_AB24 = _mm256_unpackhi_epi8(L_A, L_B);
1831 
1832       __m256i L_C = lhs.template loadPacket<Packet>(m, k + 2);
1833       __m256i L_D = lhs.template loadPacket<Packet>(m, k + 3);
1834       __m256i L_CD0_CD16 = _mm256_unpacklo_epi8(L_C, L_D);
1835       __m256i L_CD8_CD24 = _mm256_unpackhi_epi8(L_C, L_D);
1836 
1837       // Interleave 16-bit elements
1838       __m256i L_AD0_AD16 = _mm256_unpacklo_epi16(L_AB0_AB16, L_CD0_CD16);
1839       __m256i L_AD4_AD20 = _mm256_unpackhi_epi16(L_AB0_AB16, L_CD0_CD16);
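      // L_AD0_AD16 now holds rows m..m+3 and m+16..m+19 with four consecutive
      // depth bytes per row; L_AD4_AD20 holds rows m+4..m+7 and m+20..m+23.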
1840 
1841       // Use permute before we store to cross 128-bit lanes
1842       __m256i L_AD0 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x20);
1843       _mm256_store_si256(blockA_256++, L_AD0);
1844 
1845       // Complete packing for 32 x 8 block
1846       __m256i L_AD16 = _mm256_permute2x128_si256(L_AD0_AD16, L_AD4_AD20, 0x31);
1847       __m256i L_AD8_AD24 = _mm256_unpacklo_epi16(L_AB8_AB24, L_CD8_CD24);
1848       __m256i L_AD12_AD28 = _mm256_unpackhi_epi16(L_AB8_AB24, L_CD8_CD24);
1849       __m256i L_AD8 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x20);
1850       _mm256_store_si256(blockA_256++, L_AD8);
1851       _mm256_store_si256(blockA_256++, L_AD16);
1852       __m256i L_AD24 = _mm256_permute2x128_si256(L_AD8_AD24, L_AD12_AD28, 0x31);
1853       _mm256_store_si256(blockA_256++, L_AD24);
1854       __m256i L_E = lhs.template loadPacket<Packet>(m, k + 4);
1855       __m256i L_F = lhs.template loadPacket<Packet>(m, k + 5);
1856       __m256i L_EF0_EF16 = _mm256_unpacklo_epi8(L_E, L_F);
1857       __m256i L_EF8_EF24 = _mm256_unpackhi_epi8(L_E, L_F);
1858       __m256i L_G = lhs.template loadPacket<Packet>(m, k + 6);
1859       __m256i L_H = lhs.template loadPacket<Packet>(m, k + 7);
1860       __m256i L_GH0_GH16 = _mm256_unpacklo_epi8(L_G, L_H);
1861       __m256i L_GH8_GH24 = _mm256_unpackhi_epi8(L_G, L_H);
1862       __m256i L_EH0_EH16 = _mm256_unpacklo_epi16(L_EF0_EF16, L_GH0_GH16);
1863       __m256i L_EH4_EH20 = _mm256_unpackhi_epi16(L_EF0_EF16, L_GH0_GH16);
1864       __m256i L_EH0 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x20);
1865       _mm256_store_si256(blockA_256++, L_EH0);
1866       __m256i L_EH16 = _mm256_permute2x128_si256(L_EH0_EH16, L_EH4_EH20, 0x31);
1867       __m256i L_EH8_EH24 = _mm256_unpacklo_epi16(L_EF8_EF24, L_GH8_GH24);
1868       __m256i L_EH12_EH28 = _mm256_unpackhi_epi16(L_EF8_EF24, L_GH8_GH24);
1869       __m256i L_EH8 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x20);
1870       _mm256_store_si256(blockA_256++, L_EH8);
1871       _mm256_store_si256(blockA_256++, L_EH16);
1872       __m256i L_EH24 = _mm256_permute2x128_si256(L_EH8_EH24, L_EH12_EH28, 0x31);
1873       _mm256_store_si256(blockA_256++, L_EH24);
1874     }
1875   }
1876 }
1877 
1878 // Arrange a block of the right input matrix in contiguous memory.
1879 //
1880 // Given column major input (A0 beside A1 in memory):
1881 // A0 B0 C0 D0 E0 F0 G0 H0 ...
1882 // A1 B1 C1 D1 E1 F1 G1 H1 ...
1883 // A2 B2 C2 D2 E2 F2 G2 H2 ...
1884 // A3 B3 C3 D3 E3 F3 G3 H3 ...
1885 // A4 B4 C4 D4 E4 F4 G4 H4 ...
1886 // A5 B5 C5 D5 E5 F5 G5 H5 ...
1887 // A6 B6 C6 D6 E6 F6 G6 H6 ...
1888 // A7 B7 C7 D7 E7 F7 G7 H7 ...
1889 // A8 ...
1890 // ...
1891 //
1892 // Packing yields row major output (A0 beside A1 in memory):
1893 // A0 A1 A2 A3 A4 A5 A6 A7
1894 // B0 B1 B2 B3 B4 B5 B6 B7
1895 // ...
1896 //
1897 // At least four elements of the same column are arranged contiguously because
1898 // maddubs and madd both perform an adjacent addition in the kernel.  We save
1899 // work by keeping 8 adjacent elements together, since kr = 8.
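// Each 32-column x 32-depth panel therefore occupies 32 packed vectors:
// vectors 0-7 hold depths k..k+7 for the eight groups of four columns,
// vectors 8-15 hold depths k+8..k+15, and so on, so the kernel can read one
// 8-depth slice of all 32 columns with eight consecutive loads.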
1900 template <typename Index, typename DataMapper, int nr, bool Conjugate,
1901           bool PanelMode>
1902 struct gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
1903                      PanelMode> {
1904   EIGEN_DONT_INLINE void operator()(QUInt8* blockB, const DataMapper& rhs,
1905                                     Index depth, Index cols, Index stride = 0,
1906                                     Index offset = 0);
1907 };
1908 
1909 template <typename Index, typename DataMapper, int nr, bool Conjugate,
1910           bool PanelMode>
1911 EIGEN_DONT_INLINE void gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor,
1912                                      Conjugate, PanelMode>::
1913 operator()(QUInt8* blockB, const DataMapper& rhs, Index depth, Index cols,
1914            Index stride, Index offset) {
1915   eigen_assert(stride == 0);
1916   eigen_assert(offset == 0);
1917 
1918   typedef typename packet_traits<QUInt8>::type Packet;
1919 
1920   // Use the any-size routine when cols or depth is not a multiple of 32
1921   if (cols % 32 != 0 || depth % 32 != 0) {
1922     gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
1923                       PanelMode> rhs_pack;
1924     return rhs_pack(blockB, rhs, depth, cols, stride, offset);
1925   }
1926 
1927   // Get vector pointer
1928   __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);
1929 
1930   // Perform a step of the packing for 4 columns
1931   __m256i R_AB_L, R_AB_H, R_CD_L, R_CD_H, R_AD_0, R_AD_8, R_AD_16, R_AD_24;
1932 #define PACK_STEP                                            \
1933   R_AB_L = _mm256_unpacklo_epi64(R_A, R_B);                  \
1934   R_CD_L = _mm256_unpacklo_epi64(R_C, R_D);                  \
1935   R_AB_H = _mm256_unpackhi_epi64(R_A, R_B);                  \
1936   R_CD_H = _mm256_unpackhi_epi64(R_C, R_D);                  \
1937   R_AD_0 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x20);  \
1938   R_AD_16 = _mm256_permute2x128_si256(R_AB_L, R_CD_L, 0x31); \
1939   R_AD_8 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x20);  \
1940   R_AD_24 = _mm256_permute2x128_si256(R_AB_H, R_CD_H, 0x31); \
1941   _mm256_store_si256(blockB_256, R_AD_0);                    \
1942   _mm256_store_si256(blockB_256 + 8, R_AD_8);                \
1943   _mm256_store_si256(blockB_256 + 16, R_AD_16);              \
1944   _mm256_store_si256(blockB_256 + 24, R_AD_24);              \
1945   blockB_256++;
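  // As in the any-size path above, each PACK_STEP regroups 32 depth entries of
  // four adjacent columns into four vectors of 8 depth values x 4 columns,
  // stored at offsets 0, 8, 16 and 24 of the current 32-vector panel.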
1946 
1947   // Pack cols in sets of 32
1948   for (Index n = 0; n < cols; n += 32) {
1949     // Pack depth in sets of 32
1950     for (Index k = 0; k < depth; k += 32) {
1951       __m256i R_A = rhs.template loadPacket<Packet>(k, n);
1952       __m256i R_B = rhs.template loadPacket<Packet>(k, n + 1);
1953       __m256i R_C = rhs.template loadPacket<Packet>(k, n + 2);
1954       __m256i R_D = rhs.template loadPacket<Packet>(k, n + 3);
1955       PACK_STEP;
1956 
1957       R_A = rhs.template loadPacket<Packet>(k, n + 4);
1958       R_B = rhs.template loadPacket<Packet>(k, n + 5);
1959       R_C = rhs.template loadPacket<Packet>(k, n + 6);
1960       R_D = rhs.template loadPacket<Packet>(k, n + 7);
1961       PACK_STEP;
1962 
1963       R_A = rhs.template loadPacket<Packet>(k, n + 8);
1964       R_B = rhs.template loadPacket<Packet>(k, n + 9);
1965       R_C = rhs.template loadPacket<Packet>(k, n + 10);
1966       R_D = rhs.template loadPacket<Packet>(k, n + 11);
1967       PACK_STEP;
1968 
1969       R_A = rhs.template loadPacket<Packet>(k, n + 12);
1970       R_B = rhs.template loadPacket<Packet>(k, n + 13);
1971       R_C = rhs.template loadPacket<Packet>(k, n + 14);
1972       R_D = rhs.template loadPacket<Packet>(k, n + 15);
1973       PACK_STEP;
1974 
1975       R_A = rhs.template loadPacket<Packet>(k, n + 16);
1976       R_B = rhs.template loadPacket<Packet>(k, n + 17);
1977       R_C = rhs.template loadPacket<Packet>(k, n + 18);
1978       R_D = rhs.template loadPacket<Packet>(k, n + 19);
1979       PACK_STEP;
1980 
1981       R_A = rhs.template loadPacket<Packet>(k, n + 20);
1982       R_B = rhs.template loadPacket<Packet>(k, n + 21);
1983       R_C = rhs.template loadPacket<Packet>(k, n + 22);
1984       R_D = rhs.template loadPacket<Packet>(k, n + 23);
1985       PACK_STEP;
1986 
1987       R_A = rhs.template loadPacket<Packet>(k, n + 24);
1988       R_B = rhs.template loadPacket<Packet>(k, n + 25);
1989       R_C = rhs.template loadPacket<Packet>(k, n + 26);
1990       R_D = rhs.template loadPacket<Packet>(k, n + 27);
1991       PACK_STEP;
1992 
1993       R_A = rhs.template loadPacket<Packet>(k, n + 28);
1994       R_B = rhs.template loadPacket<Packet>(k, n + 29);
1995       R_C = rhs.template loadPacket<Packet>(k, n + 30);
1996       R_D = rhs.template loadPacket<Packet>(k, n + 31);
1997       PACK_STEP;
1998 
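      // The eight PACK_STEPs above advanced blockB_256 by 8; skip the
      // remaining 24 slots to reach the next 32-vector panel.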
1999       blockB_256 += 24;
2000     }
2001   }
2002 #undef PACK_STEP
2003 }
2004 
2005 // Perform the actual multiplication on packed inputs
2006 template <typename Index, typename DataMapper, int mr, int nr,
2007           bool ConjugateLhs, bool ConjugateRhs>
2008 struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
2009                    ConjugateRhs> {
2010   typedef typename DataMapper::LinearMapper LinearMapper;
2011 
2012   EIGEN_DONT_INLINE
2013   void operator()(const DataMapper& res, const QInt8* blockA,
2014                   const QUInt8* blockB, Index rows, Index depth, Index cols,
2015                   QInt32 alpha, Index strideA = -1, Index strideB = -1,
2016                   Index offsetA = 0, Index offsetB = 0);
2017 };
2018 
2019 template <typename Index, typename DataMapper, int mr, int nr,
2020           bool ConjugateLhs, bool ConjugateRhs>
2021 EIGEN_DONT_INLINE void gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr,
2022                                    ConjugateLhs, ConjugateRhs>::
2023 operator()(const DataMapper& res, const QInt8* blockA, const QUInt8* blockB,
2024            Index rows, Index depth, Index cols, QInt32 alpha, Index strideA,
2025            Index strideB, Index offsetA, Index offsetB) {
2026   EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
2027   EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
2028   eigen_assert(alpha.value == 1);
2029   eigen_assert(strideA == -1);
2030   eigen_assert(strideB == -1);
2031   eigen_assert(offsetA == 0);
2032   eigen_assert(offsetB == 0);
2033   eigen_assert(rows > 0);
2034   eigen_assert(cols > 0);
2035   eigen_assert(depth > 0);
2036   eigen_assert(blockA);
2037   eigen_assert(blockB);
2038 
2039   // Use the any-size kernel when rows, cols, or depth is not a multiple of 32
2040   if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) {
2041     gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
2042                     ConjugateRhs> gebp;
2043     return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB,
2044                 offsetA, offsetB);
2045   }
2046 
2047   // Create result block
2048   QInt32* blockO = aligned_new<QInt32>(32 * 32);
2049   // Allocating the result block is about 5-10% faster than declaring stack
2050   // space.  It is unclear why this is the case.
2051   // ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
2052   memset(blockO, 0, 32 * 32 * sizeof(QInt32));
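  // As in the any-size kernel, blockO holds the 32x32 accumulator tile in
  // column-major order, four __m256i per result column.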
2053 
2054   // Get vectorized pointers
2055   __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
2056   const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
2057   const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);
2058 
2059   // Loop over blocks of 32 columns
2060   for (Index n = 0; n < cols; n += 32) {
2061     // Reset index into blockA
2062     Index indexL = 0;
2063     // Loop over blocks of 32 rows
2064     for (Index m = 0; m < rows; m += 32) {
2065       // Reset index into blockB
2066       Index indexR = n / 32 * depth;
2067       // Loop over blocks of 8 on depth
2068       for (Index k = 0; k < depth; k += 8) {
2069         // Load inputs
2070         __m256i L_AD0 = blockA_256[indexL++];
2071         __m256i L_AD8 = blockA_256[indexL++];
2072         __m256i L_AD16 = blockA_256[indexL++];
2073         __m256i L_AD24 = blockA_256[indexL++];
2074         __m256i L_EH0 = blockA_256[indexL++];
2075         __m256i L_EH8 = blockA_256[indexL++];
2076         __m256i L_EH16 = blockA_256[indexL++];
2077         __m256i L_EH24 = blockA_256[indexL++];
2078         __m256i R_AH0 = blockB_256[indexR++];
2079         __m256i R_AH4 = blockB_256[indexR++];
2080         __m256i R_AH8 = blockB_256[indexR++];
2081         __m256i R_AH12 = blockB_256[indexR++];
2082         __m256i R_AH16 = blockB_256[indexR++];
2083         __m256i R_AH20 = blockB_256[indexR++];
2084         __m256i R_AH24 = blockB_256[indexR++];
2085         __m256i R_AH28 = blockB_256[indexR++];
2086 
2087         // Used with madd to sum adjacent 16-bit values into 32-bit results
2088         const __m256i ONE = _mm256_set1_epi32(0x00010001);
2089 
2090         // Declare variables used in COMPUTE_STEP
2091         __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;
2092 
2093 #define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                             \
2094   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0);                             \
2095   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
2096   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0);                             \
2097   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
2098   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
2099   _mm256_store_si256(                                                          \
2100       blockO_256 + 4 * OFFSET,                                                 \
2101       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32));     \
2102                                                                                \
2103   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8);                             \
2104   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
2105   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8);                             \
2106   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
2107   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
2108   _mm256_store_si256(                                                          \
2109       blockO_256 + 4 * OFFSET + 1,                                             \
2110       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \
2111                                                                                \
2112   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16);                            \
2113   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
2114   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16);                            \
2115   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
2116   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
2117   _mm256_store_si256(                                                          \
2118       blockO_256 + 4 * OFFSET + 2,                                             \
2119       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \
2120                                                                                \
2121   P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24);                            \
2122   P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
2123   P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24);                            \
2124   P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
2125   P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
2126   _mm256_store_si256(                                                          \
2127       blockO_256 + 4 * OFFSET + 3,                                             \
2128       _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32));
2129 
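        // As in the any-size kernel, COMPUTE_STEP accumulates the contribution
        // of 8 depth entries of the result column selected by OFFSET into the
        // corresponding four vectors of blockO.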
2130         // Permute and shuffle to broadcast four depth bytes of a single RHS
2131         // column across each vector, then accumulate that column's products
2132         __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
2133         __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2134         __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2135         COMPUTE_STEP(R_AD0, R_EH0, 0);
2136         __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2137         __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2138         COMPUTE_STEP(R_AD1, R_EH1, 1);
2139         R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
2140         __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2141         __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2142         COMPUTE_STEP(R_AD2, R_EH2, 2);
2143         __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2144         __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2145         COMPUTE_STEP(R_AD3, R_EH3, 3);
2146 
2147         R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
2148         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2149         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2150         COMPUTE_STEP(R_AD0, R_EH0, 4);
2151         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2152         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2153         COMPUTE_STEP(R_AD1, R_EH1, 5);
2154         R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
2155         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2156         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2157         COMPUTE_STEP(R_AD2, R_EH2, 6);
2158         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2159         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2160         COMPUTE_STEP(R_AD3, R_EH3, 7);
2161 
2162         R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
2163         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2164         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2165         COMPUTE_STEP(R_AD0, R_EH0, 8);
2166         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2167         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2168         COMPUTE_STEP(R_AD1, R_EH1, 9);
2169         R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
2170         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2171         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2172         COMPUTE_STEP(R_AD2, R_EH2, 10);
2173         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2174         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2175         COMPUTE_STEP(R_AD3, R_EH3, 11);
2176 
2177         R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
2178         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2179         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2180         COMPUTE_STEP(R_AD0, R_EH0, 12);
2181         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2182         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2183         COMPUTE_STEP(R_AD1, R_EH1, 13);
2184         R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
2185         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2186         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2187         COMPUTE_STEP(R_AD2, R_EH2, 14);
2188         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2189         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2190         COMPUTE_STEP(R_AD3, R_EH3, 15);
2191 
2192         R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00);
2193         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2194         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2195         COMPUTE_STEP(R_AD0, R_EH0, 16);
2196         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2197         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2198         COMPUTE_STEP(R_AD1, R_EH1, 17);
2199         R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11);
2200         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2201         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2202         COMPUTE_STEP(R_AD2, R_EH2, 18);
2203         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2204         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2205         COMPUTE_STEP(R_AD3, R_EH3, 19);
2206 
2207         R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00);
2208         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2209         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2210         COMPUTE_STEP(R_AD0, R_EH0, 20);
2211         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2212         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2213         COMPUTE_STEP(R_AD1, R_EH1, 21);
2214         R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11);
2215         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2216         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2217         COMPUTE_STEP(R_AD2, R_EH2, 22);
2218         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2219         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2220         COMPUTE_STEP(R_AD3, R_EH3, 23);
2221 
2222         R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00);
2223         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2224         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2225         COMPUTE_STEP(R_AD0, R_EH0, 24);
2226         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2227         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2228         COMPUTE_STEP(R_AD1, R_EH1, 25);
2229         R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11);
2230         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2231         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2232         COMPUTE_STEP(R_AD2, R_EH2, 26);
2233         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2234         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2235         COMPUTE_STEP(R_AD3, R_EH3, 27);
2236 
2237         R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00);
2238         R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2239         R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2240         COMPUTE_STEP(R_AD0, R_EH0, 28);
2241         R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2242         R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2243         COMPUTE_STEP(R_AD1, R_EH1, 29);
2244         R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11);
2245         R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
2246         R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
2247         COMPUTE_STEP(R_AD2, R_EH2, 30);
2248         R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
2249         R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
2250         COMPUTE_STEP(R_AD3, R_EH3, 31);
2251 
2252 #undef COMPUTE_STEP
2253       }
2254 
2255       // Transfer the results to the result matrix
2256       Index i = 0;
2257       for (Index j = n; j < n + 32; j++) {
2258         LinearMapper r0 = res.getLinearMapper(m, j);
2259         LinearMapper r1 = res.getLinearMapper(m + 8, j);
2260         LinearMapper r2 = res.getLinearMapper(m + 16, j);
2261         LinearMapper r3 = res.getLinearMapper(m + 24, j);
2262         typedef typename packet_traits<QInt32>::type Packet;
2263         r0.template storePacket<Packet>(
2264             0, _mm256_add_epi32(blockO_256[i++],
2265                                 r0.template loadPacket<Packet>(0)));
2266         r1.template storePacket<Packet>(
2267             0, _mm256_add_epi32(blockO_256[i++],
2268                                 r1.template loadPacket<Packet>(0)));
2269         r2.template storePacket<Packet>(
2270             0, _mm256_add_epi32(blockO_256[i++],
2271                                 r2.template loadPacket<Packet>(0)));
2272         r3.template storePacket<Packet>(
2273             0, _mm256_add_epi32(blockO_256[i++],
2274                                 r3.template loadPacket<Packet>(0)));
2275       }
2276 
2277       // Zero the result block so it can be reused
2278       memset(blockO, 0, 32 * 32 * sizeof(QInt32));
2279     }
2280   }
2281   aligned_delete(blockO, 32 * 32);
2282 }
2283 
2284 #endif  // EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT
2285 
2286 }  // namespace internal
2287 }  // namespace Eigen
2288 
2289 #endif  // CXX11_SRC_FIXEDPOINT_MATMATPRODUCTAVX2_H_
2290