1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <immintrin.h>
10
11 #include <qnnpack/q8gemm_sparse.h>
12 #include <requantization/runtime-sse2.h>
13
14 #include "8x4c1x4-packed-sse2.h"
15
16 // This is a super slow kernel in that it does not use intrinsics to
17 // tranpose. Since this is for x86 we are not optimizing it.
18 // For ARM this will be optimized.
pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2(const size_t mr,const size_t K,const uint8_t * a,const size_t a_stride,uint8_t * a_packed)19 void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2(
20 const size_t mr,
21 const size_t K,
22 const uint8_t* a,
23 const size_t a_stride,
24 uint8_t* a_packed) {
25
26 // Packed A format.
27 // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory.
28 // Original A
29 // --------- K ----------- -- (K + 4 - 1) / 4 --
30 // | | | |
31 // | | (M + 8 - 1)/8 |
32 // | | Packed | |
33 // M | => |-------------------|
34 // | | Thus Packed A has (K + 4 - 1)/4 * (M + 8 -1)/8 blocks
35 // | |
36 // |---------------------|
37 //
38 // Each 8 x 4 blocks is transposed and stored.
39 // Each of the (K + 4 - 1)/4 blocks for a given group of 8 m blocks
40 // are stored adjacent in memory
41 // Thus, each block:
42 // |----8m-----|----8m-----|
43 // 4k | | ..... (K + 4 - 1)/4 blocks
44 // |-----------|-----------|
45 // This locality helps in loading 8kx8m blocks of activations
46 // Note when M is not multiple of 8, the rest can contain arbitrary
47 // data in packed A as we will not be writing those out.
48 // This wil be taken care by just copying the appropriate valid data
49
50 // Note that parts of A that are not filled are:
51 // Remainder of M blocks. So some m values are random. This is ok
52 // because when sparse gemm accumulated into them, those values will not
53 // be written out.
54 // Remainder of K blocks. When K is not multiple of 4 the remaining k
55 // in 4x8 blocks are also random. this is also ok because the packed
56 // weights will be packed with zeros such that multiplication will result
57 // in zero.
58 uint32_t num_k_blocks = (K + COL_BLOCK_SIZE -1) / COL_BLOCK_SIZE;
59 for (uint32_t k_block = 0; k_block < num_k_blocks - 1; k_block++) {
60 for (uint32_t k = 0; k < COL_BLOCK_SIZE; k++) {
61 for (uint32_t m = 0; m < mr; m++) {
62 *(a_packed + k_block * PACKED_A_BLOCK_SIZE + k * 8 + m) =
63 *(a + m * a_stride + k_block * COL_BLOCK_SIZE + k);
64 }
65 }
66 }
67 for (uint32_t k = 0; k < (K - ((num_k_blocks - 1) * COL_BLOCK_SIZE)); k++) {
68 for (uint32_t m = 0; m < mr; m++) {
69 *(a_packed + (num_k_blocks - 1) * PACKED_A_BLOCK_SIZE + k * 8 + m) =
70 *(a + m * a_stride + (num_k_blocks - 1) * COL_BLOCK_SIZE + k);
71 }
72 }
73
74 }
75