/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/common.h>
#include <qnnpack/q8vadd.h>

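/*
 * Quantized 8-bit elementwise addition:
 *   y[i] = clamp(y_zero_point +
 *                rounded_shift((a[i] - a_zero_point) * a_multiplier +
 *                              (b[i] - b_zero_point) * b_multiplier,
 *                              right_shift),
 *                y_min, y_max)
 * computed in 32-bit fixed point; see the per-step comments below.
 */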
void pytorch_q8vadd_ukernel__neon(
    size_t n,
    const uint8_t* a,
    const uint8_t* b,
    uint8_t* y,
    const union pytorch_qnnp_add_quantization_params
        quantization_params[restrict static 1]) {
  const uint8x8_t va_zero_point =
      vld1_dup_u8(&quantization_params->neon.a_zero_point);
  const uint8x8_t vb_zero_point =
      vld1_dup_u8(&quantization_params->neon.b_zero_point);
  const int16x8_t vy_zero_point =
      vld1q_dup_s16(&quantization_params->neon.y_zero_point);
  const int32x4_t va_multiplier =
      vld1q_dup_s32(&quantization_params->neon.a_multiplier);
  const int32x4_t vb_multiplier =
      vld1q_dup_s32(&quantization_params->neon.b_multiplier);
  const int32x4_t vright_shift =
      vld1q_dup_s32(&quantization_params->neon.right_shift);
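  /* Lanes of vzero_shift_mask are all-ones where right_shift == 0 and zero
   * otherwise; they are used below to skip the rounding adjustment when no
   * shift is applied. */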
  const int32x4_t vzero_shift_mask =
      vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
  const uint8x16_t vy_max = vld1q_dup_u8(&quantization_params->neon.y_max);
  const uint8x16_t vy_min = vld1q_dup_u8(&quantization_params->neon.y_min);
  if PYTORCH_QNNP_LIKELY(n >= 8) {
#ifdef __aarch64__
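      /* Main AArch64 loop: 32 elements per iteration (two 16-byte loads per
       * operand). */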
      for (; n >= 32; n -= 32) {
        const uint8x16_t va01 = vld1q_u8(a);
        a += 16;
        const uint8x16_t vb01 = vld1q_u8(b);
        b += 16;
        const uint8x16_t va23 = vld1q_u8(a);
        a += 16;
        const uint8x16_t vb23 = vld1q_u8(b);
        b += 16;

        /* Subtract zero point */
        const int16x8_t vxa0 =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point));
        const int16x8_t vxb0 =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point));
        const int16x8_t vxa1 =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point));
        const int16x8_t vxb1 =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point));
        const int16x8_t vxa2 =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va23), va_zero_point));
        const int16x8_t vxb2 =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb23), vb_zero_point));
        const int16x8_t vxa3 =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va23), va_zero_point));
        const int16x8_t vxb3 =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb23), vb_zero_point));

        /* Multiply by factors and accumulate products */
        int32x4_t vacc0_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier);
        int32x4_t vacc1_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier);
        int32x4_t vacc2_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa2)), va_multiplier);
        int32x4_t vacc3_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa3)), va_multiplier);
        int32x4_t vacc0_hi = vmulq_s32(vmovl_high_s16(vxa0), va_multiplier);
        int32x4_t vacc1_hi = vmulq_s32(vmovl_high_s16(vxa1), va_multiplier);
        int32x4_t vacc2_hi = vmulq_s32(vmovl_high_s16(vxa2), va_multiplier);
        int32x4_t vacc3_hi = vmulq_s32(vmovl_high_s16(vxa3), va_multiplier);

        vacc0_lo =
            vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier);
        vacc1_lo =
            vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier);
        vacc2_lo =
            vmlaq_s32(vacc2_lo, vmovl_s16(vget_low_s16(vxb2)), vb_multiplier);
        vacc3_lo =
            vmlaq_s32(vacc3_lo, vmovl_s16(vget_low_s16(vxb3)), vb_multiplier);
        vacc0_hi = vmlaq_s32(vacc0_hi, vmovl_high_s16(vxb0), vb_multiplier);
        vacc1_hi = vmlaq_s32(vacc1_hi, vmovl_high_s16(vxb1), vb_multiplier);
        vacc2_hi = vmlaq_s32(vacc2_hi, vmovl_high_s16(vxb2), vb_multiplier);
        vacc3_hi = vmlaq_s32(vacc3_hi, vmovl_high_s16(vxb3), vb_multiplier);

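        /* vsraq_n_s32 below adds (acc >> 31) to each accumulator, i.e. it
         * subtracts 1 from negative lanes; vzero_shift_mask suppresses the
         * adjustment when right_shift == 0. Combined with the rounding shift
         * (vrshlq_s32), this rounds halfway values away from zero instead of
         * toward positive infinity. */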
        /* Shift right and round */
        vacc0_lo =
            vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
        vacc1_lo =
            vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
        vacc2_lo =
            vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
        vacc3_lo =
            vsraq_n_s32(vacc3_lo, vbicq_s32(vacc3_lo, vzero_shift_mask), 31);
        vacc0_hi =
            vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
        vacc1_hi =
            vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
        vacc2_hi =
            vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
        vacc3_hi =
            vsraq_n_s32(vacc3_hi, vbicq_s32(vacc3_hi, vzero_shift_mask), 31);

        vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
        vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
        vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
        vacc3_lo = vrshlq_s32(vacc3_lo, vright_shift);
        vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
        vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
        vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
        vacc3_hi = vrshlq_s32(vacc3_hi, vright_shift);

        /* Pack, saturate, and add output zero point */
        const int16x8_t vacc0 = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), vy_zero_point);
        const int16x8_t vacc1 = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), vy_zero_point);
        const int16x8_t vacc2 = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), vy_zero_point);
        const int16x8_t vacc3 = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc3_lo), vacc3_hi), vy_zero_point);

        uint8x16_t vy01 = vqmovun_high_s16(vqmovun_s16(vacc0), vacc1);
        uint8x16_t vy23 = vqmovun_high_s16(vqmovun_s16(vacc2), vacc3);

        vy01 = vmaxq_u8(vy01, vy_min);
        vy23 = vmaxq_u8(vy23, vy_min);
        vy01 = vminq_u8(vy01, vy_max);
        vy23 = vminq_u8(vy23, vy_max);

        vst1q_u8(y, vy01);
        y += 16;
        vst1q_u8(y, vy23);
        y += 16;
      }
#else
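      /* AArch32 NEON loop: 16 elements per iteration. */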
      for (; n >= 16; n -= 16) {
        const uint8x16_t va01 = vld1q_u8(a);
        a += 16;
        const uint8x16_t vb01 = vld1q_u8(b);
        b += 16;

        /* Subtract zero point */
        const int16x8_t vxa0 =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point));
        const int16x8_t vxb0 =
            vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point));
        const int16x8_t vxa1 =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point));
        const int16x8_t vxb1 =
            vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point));

        /* Multiply by factors and accumulate products */
        int32x4_t vacc0_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier);
        int32x4_t vacc1_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier);
        int32x4_t vacc0_hi =
            vmulq_s32(vmovl_s16(vget_high_s16(vxa0)), va_multiplier);
        int32x4_t vacc1_hi =
            vmulq_s32(vmovl_s16(vget_high_s16(vxa1)), va_multiplier);

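        /* Prefetch well ahead of the current position to hide memory latency;
         * the 640-byte distance is a tuning choice. */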
        __builtin_prefetch(a + 640);
        __builtin_prefetch(b + 640);

        vacc0_lo =
            vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier);
        vacc1_lo =
            vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier);
        vacc0_hi =
            vmlaq_s32(vacc0_hi, vmovl_s16(vget_high_s16(vxb0)), vb_multiplier);
        vacc1_hi =
            vmlaq_s32(vacc1_hi, vmovl_s16(vget_high_s16(vxb1)), vb_multiplier);

        /* Shift right and round */
        vacc0_lo =
            vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
        vacc1_lo =
            vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
        vacc0_hi =
            vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
        vacc1_hi =
            vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);

        vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
        vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
        vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
        vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);

        /* Pack, saturate, and add output zero point */
        const int16x8_t vacc0 = vqaddq_s16(
            vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)),
            vy_zero_point);
        const int16x8_t vacc1 = vqaddq_s16(
            vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)),
            vy_zero_point);

        uint8x16_t vy01 = vcombine_u8(vqmovun_s16(vacc0), vqmovun_s16(vacc1));
        vy01 = vmaxq_u8(vy01, vy_min);
        vy01 = vminq_u8(vy01, vy_max);

        vst1q_u8(y, vy01);
        y += 16;
      }
#endif
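      /* Common tail loop: 8 elements per iteration on both AArch32 and
       * AArch64. */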
      for (; n >= 8; n -= 8) {
        const uint8x8_t va = vld1_u8(a);
        a += 8;
        const uint8x8_t vb = vld1_u8(b);
        b += 8;

        /* Subtract zero point */
        const int16x8_t vxa =
            vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point));
        const int16x8_t vxb =
            vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point));

        /* Multiply by factors and accumulate products */
        int32x4_t vacc_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier);
#ifdef __aarch64__
        int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier);
#else
        int32x4_t vacc_hi =
            vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier);
#endif

        vacc_lo =
            vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier);
#ifdef __aarch64__
        vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier);
#else
        vacc_hi =
            vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier);
#endif

        /* Shift right and round */
        vacc_lo =
            vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
        vacc_hi =
            vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

        vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
        vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

        /* Pack, saturate, and add output zero point */
#ifdef __aarch64__
        const int16x8_t vacc = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point);
#else
        const int16x8_t vacc = vqaddq_s16(
            vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)),
            vy_zero_point);
#endif

        uint8x8_t vy = vqmovun_s16(vacc);
        vy = vmax_u8(vy, vget_low_u8(vy_min));
        vy = vmin_u8(vy, vget_low_u8(vy_max));

        vst1_u8(y, vy);
        y += 8;
      }
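      /* Handle the remaining 1-7 elements. Since this branch requires the
       * original n to be at least 8, it is safe to reload the 8 bytes ending
       * at the last remaining element of each input; the 64-bit shift below
       * (vshl_u64 with a negative count, i.e. a right shift) discards the
       * already-processed bytes and moves the n remaining bytes into the low
       * lanes. */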
      if (n != 0) {
        const size_t n_increment = n - 8;
        const int64x1_t vld_shift = vmov_n_s64(8 * n_increment);
        const uint8x8_t va = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(a + n_increment)), vld_shift));
        const uint8x8_t vb = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(b + n_increment)), vld_shift));

        /* Subtract zero point */
        const int16x8_t vxa =
            vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point));
        const int16x8_t vxb =
            vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point));

        /* Multiply by factors and accumulate products */
        int32x4_t vacc_lo =
            vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier);
#ifdef __aarch64__
        int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier);
#else
        int32x4_t vacc_hi =
            vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier);
#endif

        vacc_lo =
            vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier);
#ifdef __aarch64__
        vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier);
#else
        vacc_hi =
            vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier);
#endif

        /* Shift right and round */
        vacc_lo =
            vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
        vacc_hi =
            vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

        vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
        vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

        /* Pack, saturate, and add output zero point */
#ifdef __aarch64__
        const int16x8_t vacc = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point);
#else
        const int16x8_t vacc = vqaddq_s16(
            vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)),
            vy_zero_point);
#endif

        uint8x8_t vy = vqmovun_s16(vacc);
        vy = vmax_u8(vy, vget_low_u8(vy_min));
        vy = vmin_u8(vy, vget_low_u8(vy_max));

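        /* Store the final 1-7 bytes with 4-, 2-, and 1-byte lane stores,
         * rotating the vector after each partial store. */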
        if (n & 4) {
          vst1_lane_u32(
              __builtin_assume_aligned(y, 1), vreinterpret_u32_u8(vy), 0);
          y += 4;
          vy = vext_u8(vy, vy, 4);
        }
        if (n & 2) {
          vst1_lane_u16(
              __builtin_assume_aligned(y, 1), vreinterpret_u16_u8(vy), 0);
          y += 2;
          vy = vext_u8(vy, vy, 2);
        }
        if (n & 1) {
          vst1_lane_u8(y, vy, 0);
        }
      }
  } else {
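    /* Short input (n < 8): process one element per iteration using only the
     * low lanes of the quantization parameters. */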
    for (; n != 0; n--) {
      const uint8x8_t va = vld1_dup_u8(a);
      a += 1;
      const uint8x8_t vb = vld1_dup_u8(b);
      b += 1;

      /* Subtract zero point */
      const int16x4_t vxa =
          vreinterpret_s16_u16(vget_low_u16(vsubl_u8(va, va_zero_point)));
      const int16x4_t vxb =
          vreinterpret_s16_u16(vget_low_u16(vsubl_u8(vb, vb_zero_point)));

      /* Multiply by factors and accumulate products */
      int32x2_t vacc =
          vmul_s32(vget_low_s32(vmovl_s16(vxa)), vget_low_s32(va_multiplier));
      vacc = vmla_s32(
          vacc, vget_low_s32(vmovl_s16(vxb)), vget_low_s32(vb_multiplier));

      /* Shift right and round */
      vacc =
          vsra_n_s32(vacc, vbic_s32(vacc, vget_low_s32(vzero_shift_mask)), 31);

      vacc = vrshl_s32(vacc, vget_low_s32(vright_shift));

      const int16x4_t vacc16 = vqadd_s16(
          vqmovn_s32(vcombine_s32(vacc, vacc)), vget_low_s16(vy_zero_point));

      /* Pack, saturate, and add output zero point */
      uint8x8_t vy = vqmovun_s16(vcombine_s16(vacc16, vacc16));
      vy = vmin_u8(vy, vget_low_u8(vy_max));
      vy = vmax_u8(vy, vget_low_u8(vy_min));

      vst1_lane_u8(y, vy, 0);
      y += 1;
    }
  }
}