/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/common.h>
#include <qnnpack/q8vadd.h>

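/*
 * Quantized element-wise addition:
 *   y[i] = clamp(rounding_shift(a_multiplier * (a[i] - a_zero_point) +
 *                               b_multiplier * (b[i] - b_zero_point)) +
 *                y_zero_point, y_min, y_max)
 * with all quantization parameters read from quantization_params->neon.
 */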
void pytorch_q8vadd_ukernel__neon(
    size_t n,
    const uint8_t* a,
    const uint8_t* b,
    uint8_t* y,
    const union pytorch_qnnp_add_quantization_params
        quantization_params[restrict static 1]) {
  const uint8x8_t va_zero_point =
      vld1_dup_u8(&quantization_params->neon.a_zero_point);
  const uint8x8_t vb_zero_point =
      vld1_dup_u8(&quantization_params->neon.b_zero_point);
  const int16x8_t vy_zero_point =
      vld1q_dup_s16(&quantization_params->neon.y_zero_point);
  const int32x4_t va_multiplier =
      vld1q_dup_s32(&quantization_params->neon.a_multiplier);
  const int32x4_t vb_multiplier =
      vld1q_dup_s32(&quantization_params->neon.b_multiplier);
  const int32x4_t vright_shift =
      vld1q_dup_s32(&quantization_params->neon.right_shift);
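  /* All-ones in lanes where right_shift == 0; used below to suppress the
   * sign-based rounding adjustment when no shift is applied. */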
  const int32x4_t vzero_shift_mask =
      vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
  const uint8x16_t vy_max = vld1q_dup_u8(&quantization_params->neon.y_max);
  const uint8x16_t vy_min = vld1q_dup_u8(&quantization_params->neon.y_min);
  if PYTORCH_QNNP_LIKELY(n >= 8) {
#ifdef __aarch64__
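    /* AArch64 main loop: process 32 elements per iteration. */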
    for (; n >= 32; n -= 32) {
      const uint8x16_t va01 = vld1q_u8(a);
      a += 16;
      const uint8x16_t vb01 = vld1q_u8(b);
      b += 16;
      const uint8x16_t va23 = vld1q_u8(a);
      a += 16;
      const uint8x16_t vb23 = vld1q_u8(b);
      b += 16;

      /* Subtract zero point */
      const int16x8_t vxa0 =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point));
      const int16x8_t vxb0 =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point));
      const int16x8_t vxa1 =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point));
      const int16x8_t vxb1 =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point));
      const int16x8_t vxa2 =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va23), va_zero_point));
      const int16x8_t vxb2 =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb23), vb_zero_point));
      const int16x8_t vxa3 =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va23), va_zero_point));
      const int16x8_t vxb3 =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb23), vb_zero_point));

      /* Multiply by factors and accumulate products */
      int32x4_t vacc0_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier);
      int32x4_t vacc1_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier);
      int32x4_t vacc2_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa2)), va_multiplier);
      int32x4_t vacc3_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa3)), va_multiplier);
      int32x4_t vacc0_hi = vmulq_s32(vmovl_high_s16(vxa0), va_multiplier);
      int32x4_t vacc1_hi = vmulq_s32(vmovl_high_s16(vxa1), va_multiplier);
      int32x4_t vacc2_hi = vmulq_s32(vmovl_high_s16(vxa2), va_multiplier);
      int32x4_t vacc3_hi = vmulq_s32(vmovl_high_s16(vxa3), va_multiplier);

      vacc0_lo =
          vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier);
      vacc1_lo =
          vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier);
      vacc2_lo =
          vmlaq_s32(vacc2_lo, vmovl_s16(vget_low_s16(vxb2)), vb_multiplier);
      vacc3_lo =
          vmlaq_s32(vacc3_lo, vmovl_s16(vget_low_s16(vxb3)), vb_multiplier);
      vacc0_hi = vmlaq_s32(vacc0_hi, vmovl_high_s16(vxb0), vb_multiplier);
      vacc1_hi = vmlaq_s32(vacc1_hi, vmovl_high_s16(vxb1), vb_multiplier);
      vacc2_hi = vmlaq_s32(vacc2_hi, vmovl_high_s16(vxb2), vb_multiplier);
      vacc3_hi = vmlaq_s32(vacc3_hi, vmovl_high_s16(vxb3), vb_multiplier);

      /* Shift right and round */
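      /* vbic zeroes the adjustment in lanes where right_shift == 0;
       * elsewhere adding (value >> 31) subtracts 1 from negative
       * accumulators so the rounding shift below breaks ties away
       * from zero. */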
      vacc0_lo =
          vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
      vacc1_lo =
          vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
      vacc2_lo =
          vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31);
      vacc3_lo =
          vsraq_n_s32(vacc3_lo, vbicq_s32(vacc3_lo, vzero_shift_mask), 31);
      vacc0_hi =
          vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
      vacc1_hi =
          vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);
      vacc2_hi =
          vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31);
      vacc3_hi =
          vsraq_n_s32(vacc3_hi, vbicq_s32(vacc3_hi, vzero_shift_mask), 31);

      vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
      vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
      vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift);
      vacc3_lo = vrshlq_s32(vacc3_lo, vright_shift);
      vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
      vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);
      vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift);
      vacc3_hi = vrshlq_s32(vacc3_hi, vright_shift);

      /* Pack, saturate, and add output zero point */
      const int16x8_t vacc0 = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), vy_zero_point);
      const int16x8_t vacc1 = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), vy_zero_point);
      const int16x8_t vacc2 = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), vy_zero_point);
      const int16x8_t vacc3 = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc3_lo), vacc3_hi), vy_zero_point);

      uint8x16_t vy01 = vqmovun_high_s16(vqmovun_s16(vacc0), vacc1);
      uint8x16_t vy23 = vqmovun_high_s16(vqmovun_s16(vacc2), vacc3);

      vy01 = vmaxq_u8(vy01, vy_min);
      vy23 = vmaxq_u8(vy23, vy_min);
      vy01 = vminq_u8(vy01, vy_max);
      vy23 = vminq_u8(vy23, vy_max);

      vst1q_u8(y, vy01);
      y += 16;
      vst1q_u8(y, vy23);
      y += 16;
    }
#else
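    /* 32-bit ARM NEON main loop: process 16 elements per iteration. */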
    for (; n >= 16; n -= 16) {
      const uint8x16_t va01 = vld1q_u8(a);
      a += 16;
      const uint8x16_t vb01 = vld1q_u8(b);
      b += 16;

      /* Subtract zero point */
      const int16x8_t vxa0 =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point));
      const int16x8_t vxb0 =
          vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point));
      const int16x8_t vxa1 =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point));
      const int16x8_t vxb1 =
          vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point));

      /* Multiply by factors and accumulate products */
      int32x4_t vacc0_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier);
      int32x4_t vacc1_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier);
      int32x4_t vacc0_hi =
          vmulq_s32(vmovl_s16(vget_high_s16(vxa0)), va_multiplier);
      int32x4_t vacc1_hi =
          vmulq_s32(vmovl_s16(vget_high_s16(vxa1)), va_multiplier);

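      /* Prefetch input data well ahead of the current position. */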
      __builtin_prefetch(a + 640);
      __builtin_prefetch(b + 640);

      vacc0_lo =
          vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier);
      vacc1_lo =
          vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier);
      vacc0_hi =
          vmlaq_s32(vacc0_hi, vmovl_s16(vget_high_s16(vxb0)), vb_multiplier);
      vacc1_hi =
          vmlaq_s32(vacc1_hi, vmovl_s16(vget_high_s16(vxb1)), vb_multiplier);

      /* Shift right and round */
      vacc0_lo =
          vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31);
      vacc1_lo =
          vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31);
      vacc0_hi =
          vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31);
      vacc1_hi =
          vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31);

      vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift);
      vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift);
      vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift);
      vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift);

      /* Pack, saturate, and add output zero point */
      const int16x8_t vacc0 = vqaddq_s16(
          vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)),
          vy_zero_point);
      const int16x8_t vacc1 = vqaddq_s16(
          vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)),
          vy_zero_point);

      uint8x16_t vy01 = vcombine_u8(vqmovun_s16(vacc0), vqmovun_s16(vacc1));
      vy01 = vmaxq_u8(vy01, vy_min);
      vy01 = vminq_u8(vy01, vy_max);

      vst1q_u8(y, vy01);
      y += 16;
    }
#endif
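    /* Process remaining groups of 8 elements. */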
    for (; n >= 8; n -= 8) {
      const uint8x8_t va = vld1_u8(a);
      a += 8;
      const uint8x8_t vb = vld1_u8(b);
      b += 8;

      /* Subtract zero point */
      const int16x8_t vxa =
          vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point));
      const int16x8_t vxb =
          vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point));

      /* Multiply by factors and accumulate products */
      int32x4_t vacc_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier);
#ifdef __aarch64__
      int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier);
#else
      int32x4_t vacc_hi =
          vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier);
#endif

      vacc_lo =
          vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier);
#ifdef __aarch64__
      vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier);
#else
      vacc_hi =
          vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier);
#endif

      /* Shift right and round */
      vacc_lo =
          vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
      vacc_hi =
          vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

      vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
      vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

      /* Pack, saturate, and add output zero point */
#ifdef __aarch64__
      const int16x8_t vacc = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point);
#else
      const int16x8_t vacc = vqaddq_s16(
          vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)),
          vy_zero_point);
#endif

      uint8x8_t vy = vqmovun_s16(vacc);
      vy = vmax_u8(vy, vget_low_u8(vy_min));
      vy = vmin_u8(vy, vget_low_u8(vy_max));

      vst1_u8(y, vy);
      y += 8;
    }
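    /* Handle the final 1-7 elements: reload the last 8 bytes of each input
     * (in bounds because at least 8 elements were already processed) and
     * shift the n valid bytes into the low lanes. */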
    if (n != 0) {
      const size_t n_increment = n - 8;
      const int64x1_t vld_shift = vmov_n_s64(8 * n_increment);
      const uint8x8_t va = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(a + n_increment)), vld_shift));
      const uint8x8_t vb = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(b + n_increment)), vld_shift));

      /* Subtract zero point */
      const int16x8_t vxa =
          vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point));
      const int16x8_t vxb =
          vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point));

      /* Multiply by factors and accumulate products */
      int32x4_t vacc_lo =
          vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier);
#ifdef __aarch64__
      int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier);
#else
      int32x4_t vacc_hi =
          vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier);
#endif

      vacc_lo =
          vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier);
#ifdef __aarch64__
      vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier);
#else
      vacc_hi =
          vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier);
#endif

      /* Shift right and round */
      vacc_lo =
          vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31);
      vacc_hi =
          vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31);

      vacc_lo = vrshlq_s32(vacc_lo, vright_shift);
      vacc_hi = vrshlq_s32(vacc_hi, vright_shift);

      /* Pack, saturate, and add output zero point */
#ifdef __aarch64__
      const int16x8_t vacc = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point);
#else
      const int16x8_t vacc = vqaddq_s16(
          vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)),
          vy_zero_point);
#endif

      uint8x8_t vy = vqmovun_s16(vacc);
      vy = vmax_u8(vy, vget_low_u8(vy_min));
      vy = vmin_u8(vy, vget_low_u8(vy_max));

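      /* Store the low n bytes in chunks of 4, 2, and 1, rotating consumed
       * bytes out of the vector after each store. */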
      if (n & 4) {
        vst1_lane_u32(
            __builtin_assume_aligned(y, 1), vreinterpret_u32_u8(vy), 0);
        y += 4;
        vy = vext_u8(vy, vy, 4);
      }
      if (n & 2) {
        vst1_lane_u16(
            __builtin_assume_aligned(y, 1), vreinterpret_u16_u8(vy), 0);
        y += 2;
        vy = vext_u8(vy, vy, 2);
      }
      if (n & 1) {
        vst1_lane_u8(y, vy, 0);
      }
    }
  } else {
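    /* Fallback for n < 8: process one element per iteration. */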
    for (; n != 0; n--) {
      const uint8x8_t va = vld1_dup_u8(a);
      a += 1;
      const uint8x8_t vb = vld1_dup_u8(b);
      b += 1;

      /* Subtract zero point */
      const int16x4_t vxa =
          vreinterpret_s16_u16(vget_low_u16(vsubl_u8(va, va_zero_point)));
      const int16x4_t vxb =
          vreinterpret_s16_u16(vget_low_u16(vsubl_u8(vb, vb_zero_point)));

      /* Multiply by factors and accumulate products */
      int32x2_t vacc =
          vmul_s32(vget_low_s32(vmovl_s16(vxa)), vget_low_s32(va_multiplier));
      vacc = vmla_s32(
          vacc, vget_low_s32(vmovl_s16(vxb)), vget_low_s32(vb_multiplier));

      /* Shift right and round */
      vacc =
          vsra_n_s32(vacc, vbic_s32(vacc, vget_low_s32(vzero_shift_mask)), 31);

      vacc = vrshl_s32(vacc, vget_low_s32(vright_shift));

      /* Pack, saturate, and add output zero point */
      const int16x4_t vacc16 = vqadd_s16(
          vqmovn_s32(vcombine_s32(vacc, vacc)), vget_low_s16(vy_zero_point));

      uint8x8_t vy = vqmovun_s16(vcombine_s16(vacc16, vacc16));
      vy = vmin_u8(vy, vget_low_u8(vy_max));
      vy = vmax_u8(vy, vget_low_u8(vy_min));

      vst1_lane_u8(y, vy, 0);
      y += 1;
    }
  }
}