xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <immintrin.h>
10 
11 #include <qnnpack/q8gemm_sparse.h>
12 #include <requantization/runtime-sse2.h>
13 
14 #include "8x4c1x4-packed-sse2.h"
15 
16 // This is a super slow kernel in that it does not use intrinsics to
17 // transpose. Since this is for x86 we are not optimizing it.
18 // For ARM this will be optimized.
pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2(const size_t mr,const size_t K,const uint8_t * a,const size_t a_stride,uint8_t * a_packed)19 void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2(
20     const size_t mr,
21     const size_t K,
22     const uint8_t* a,
23     const size_t a_stride,
24     uint8_t* a_packed) {
25 
26   // Packed A format.
27   // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory.
28   // Original A
29   // --------- K -----------          -- (K + 4 - 1) / 4 --
30   // |                     |          |                   |
31   // |                     |        (M + 8 - 1)/8         |
32   // |                     | Packed   |                   |
33   // M                     |  =>      |-------------------|
34   // |                     |        Thus Packed A has (K + 4 - 1)/4 * (M + 8 -1)/8 blocks
35   // |                     |
36   // |---------------------|
37   //
38   // Each 8 x 4 blocks is transposed and stored.
39   // Each of the (K + 4 - 1)/4 blocks for a given group of 8 m blocks
40   // are stored adjacent in memory
41   // Thus, each block:
42   // |----8m-----|----8m-----|
43   // 4k          |           | ..... (K + 4 - 1)/4 blocks
44   // |-----------|-----------|
45   // This locality helps in loading 8kx8m blocks of activations
46   // Note when M is not multiple of 8, the rest can contain arbitrary
47   // data in packed A as we will not be writing those out.
48   // This wil be taken care by just copying the appropriate valid data
49 
50   // Note that parts of A that are not filled are:
51   // Remainder of M blocks. So some m values are random. This is ok
52   // because when sparse gemm accumulated into them, those values will not
53   // be written out.
54   // Remainder of K blocks. When K is not multiple of 4 the remaining k
55   // in 4x8 blocks are also random. this is also ok because the packed
56   // weights will be packed with zeros such that multiplication will result
57   // in zero.
58   uint32_t num_k_blocks = (K + COL_BLOCK_SIZE -1) / COL_BLOCK_SIZE;
59   for (uint32_t k_block = 0; k_block < num_k_blocks - 1; k_block++) {
60     for (uint32_t k = 0; k < COL_BLOCK_SIZE; k++) {
61       for (uint32_t m = 0; m < mr; m++) {
62         *(a_packed + k_block * PACKED_A_BLOCK_SIZE + k * 8 + m) =
63           *(a + m * a_stride + k_block * COL_BLOCK_SIZE + k);
64       }
65     }
66   }
67   for (uint32_t k = 0; k < (K - ((num_k_blocks - 1) * COL_BLOCK_SIZE)); k++) {
68     for (uint32_t m = 0; m < mr; m++) {
69       *(a_packed + (num_k_blocks - 1) * PACKED_A_BLOCK_SIZE + k * 8 + m) =
70         *(a + m * a_stride + (num_k_blocks - 1) * COL_BLOCK_SIZE + k);
71     }
72   }
73 
74 }
75