/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8gemm.h>

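// Computes a_sum[i] = multiplier * sum(a[i][0..k)) for each of the m rows
// (1 <= m <= 4) of the 8-bit matrix `a`, whose rows are `stride` bytes apart.
// The scaled row sums are typically folded into the zero-point correction
// term of QNNPACK's quantized GEMM.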
void pytorch_q8sumrows_ukernel_4x__neon(
    const uint8_t* restrict a,
    size_t m,
    size_t k,
    size_t stride,
    const int32_t multiplier,
    int32_t* restrict a_sum) {
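  // Clamp the row pointers: when m < 4, the trailing pointers alias the last
  // valid row so every accumulator reads in-bounds data; only the first m
  // results are written out at the end.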
  const uint8_t* a0 = a;
  const uint8_t* a1 = a0;
  if (m >= 2) {
    a1 += stride;
  }
  const uint8_t* a2 = a1;
  if (m > 2) {
    a2 += stride;
  }
  const uint8_t* a3 = a2;
  if (m == 4) {
    a3 += stride;
  }

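  // One 32-bit x 4 accumulator per row; partial sums stay spread across the
  // four lanes and are reduced to a single value per row after the main loop.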
  uint32x4_t vacc0x0123 = vmovq_n_u32(0); // row 0
  uint32x4_t vacc1x0123 = vmovq_n_u32(0); // row 1
  uint32x4_t vacc2x0123 = vmovq_n_u32(0); // row 2
  uint32x4_t vacc3x0123 = vmovq_n_u32(0); // row 3
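  // Main loop: 16 bytes per row per iteration. vaddl_u8 adds the low and high
  // halves into eight 16-bit sums, then vpadalq_u16 pairwise-accumulates them
  // into the four 32-bit lanes.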
  for (; k >= 16; k -= 16) {
    // row 0
    const uint8x16_t va0x0_15 = vld1q_u8(a0);
    a0 += 16;
    vacc0x0123 = vpadalq_u16(
        vacc0x0123, vaddl_u8(vget_low_u8(va0x0_15), vget_high_u8(va0x0_15)));

    // row 1
    const uint8x16_t va1x0_15 = vld1q_u8(a1);
    a1 += 16;
    vacc1x0123 = vpadalq_u16(
        vacc1x0123, vaddl_u8(vget_low_u8(va1x0_15), vget_high_u8(va1x0_15)));

    // row 2
    const uint8x16_t va2x0_15 = vld1q_u8(a2);
    a2 += 16;
    vacc2x0123 = vpadalq_u16(
        vacc2x0123, vaddl_u8(vget_low_u8(va2x0_15), vget_high_u8(va2x0_15)));

    // row 3
    const uint8x16_t va3x0_15 = vld1q_u8(a3);
    a3 += 16;
    vacc3x0123 = vpadalq_u16(
        vacc3x0123, vaddl_u8(vget_low_u8(va3x0_15), vget_high_u8(va3x0_15)));
  }

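  // Remainder of 8: pairwise-add eight bytes into four 16-bit sums and
  // widen-accumulate them into the 32-bit lanes.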
  if (k >= 8) {
    vacc0x0123 = vaddw_u16(vacc0x0123, vpaddl_u8(vld1_u8(a0)));
    a0 += 8;
    vacc1x0123 = vaddw_u16(vacc1x0123, vpaddl_u8(vld1_u8(a1)));
    a1 += 8;
    vacc2x0123 = vaddw_u16(vacc2x0123, vpaddl_u8(vld1_u8(a2)));
    a2 += 8;
    vacc3x0123 = vaddw_u16(vacc3x0123, vpaddl_u8(vld1_u8(a3)));
    a3 += 8;
    k -= 8;
  }

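  // Remainder of 4: load four bytes as one (possibly unaligned) 32-bit value
  // (__builtin_assume_aligned(..., 1) marks the cast pointer as byte-aligned),
  // widen them to 16 bits, and accumulate.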
  if (k >= 4) {
    vacc0x0123 = vaddw_u16(
        vacc0x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a0, 1))))));
    a0 += 4;
    vacc1x0123 = vaddw_u16(
        vacc1x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a1, 1))))));
    a1 += 4;
    vacc2x0123 = vaddw_u16(
        vacc2x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a2, 1))))));
    a2 += 4;
    vacc3x0123 = vaddw_u16(
        vacc3x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a3, 1))))));
    a3 += 4;
    k -= 4;
  }

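  // Reduce each row's four partial sums to a single lane:
  // vacc0123 = {sum(row 0), sum(row 1), sum(row 2), sum(row 3)}.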
  const uint32x2_t vsum0x01 =
      vpadd_u32(vget_low_u32(vacc0x0123), vget_high_u32(vacc0x0123));
  const uint32x2_t vsum1x01 =
      vpadd_u32(vget_low_u32(vacc1x0123), vget_high_u32(vacc1x0123));
  const uint32x2_t vsum2x01 =
      vpadd_u32(vget_low_u32(vacc2x0123), vget_high_u32(vacc2x0123));
  const uint32x2_t vsum3x01 =
      vpadd_u32(vget_low_u32(vacc3x0123), vget_high_u32(vacc3x0123));
  uint32x4_t vacc0123 = vcombine_u32(
      vpadd_u32(vsum0x01, vsum1x01), vpadd_u32(vsum2x01, vsum3x01));

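  // Remainder of 2: load two bytes per row, shuffle the four byte pairs into
  // one vector ordered row 0..3, pairwise-add, and add one sum per lane.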
  if (k >= 2) {
    const uint8x8_t va0x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1)));
    a0 += 2;
    const uint8x8_t va1x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1)));
    a1 += 2;
    const uint8x8_t va2x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1)));
    a2 += 2;
    const uint8x8_t va3x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1)));
    a3 += 2;
    const uint8x8_t va0x01_1x010101 = vext_u8(va0x01010101, va1x01010101, 2);
    const uint8x8_t va2x01_3x010101 = vext_u8(va2x01010101, va3x01010101, 6);
    const uint8x8_t va0123x01 = vext_u8(va0x01_1x010101, va2x01_3x010101, 4);
    vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(va0123x01));
    k -= 2;
  }

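  // Final odd byte: place each row's last byte in an even lane so that
  // vpaddl_u8 yields exactly one 16-bit value per row.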
  if (k > 0) {
    uint8x8_t vax0x1x2x3 = vmov_n_u8(0);
    vax0x1x2x3 = vld1_lane_u8(a0, vax0x1x2x3, 0);
    vax0x1x2x3 = vld1_lane_u8(a1, vax0x1x2x3, 2);
    vax0x1x2x3 = vld1_lane_u8(a2, vax0x1x2x3, 4);
    vax0x1x2x3 = vld1_lane_u8(a3, vax0x1x2x3, 6);
    vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(vax0x1x2x3));
  }

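  // Scale the row sums by `multiplier` (reinterpreting them as signed) and
  // store only the first m results.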
  int32x4_t vsum0123 = vmulq_n_s32(vreinterpretq_s32_u32(vacc0123), multiplier);
  if (m == 4) {
    vst1q_s32(a_sum, vsum0123);
  } else {
    if (m >= 2) {
      vst1_s32(a_sum, vget_low_s32(vsum0123));
      a_sum += 2;
      vsum0123 = vextq_s32(vsum0123, vsum0123, 2);
      m -= 2;
    }
    if (m != 0) {
      vst1q_lane_s32(a_sum, vsum0123, 0);
    }
  }
}