/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <arm_neon.h>

#include <qnnpack/q8avgpool.h>

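/*
 * Multi-pass 8-bit average-pooling micro-kernel: processes kc >= 8 channels
 * for each of n output pixels, pooling over ks > 9 input rows. The first
 * pass sums 9 rows per channel into an int32 accumulator buffer seeded with
 * the bias; each intermediate pass adds 8 more rows; the final pass adds the
 * last 1-8 rows, then scales, rounds, and requantizes the totals to uint8.
 */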
void pytorch_q8avgpool_ukernel_mp8x9p8q__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[restrict static 1]) {
  assert(n != 0);
  assert(ks > 9);
  assert(kc >= 8);

  const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias);
  const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale);
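  /*
   * Requantization constants. AArch64 rounds with vcvtnq_s32_f32 and clamps
   * the uint8 result against output_min/output_max; AArch32 clamps in float
   * (vfmin/vfmax) and rounds via the magic-number trick (add vfmagic,
   * reinterpret as int32, subtract vimagic).
   */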
#if defined(__aarch64__)
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  const uint8x8_t voutput_min =
      vld1_dup_u8(&quantization_params->neon.output_min);
  const uint8x8_t voutput_max =
      vld1_dup_u8(&quantization_params->neon.output_max);
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
#endif

  do {
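    /*
     * First pass: sum 9 rows per channel and store bias + sum into the
     * int32 accumulator buffer.
     */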
    {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;
      const uint8_t* i8 = *input++;

      size_t k = kc;
      int32_t* acc = buffer;
      while (k >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0);
        i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1);
        i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2);
        i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3);
        i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4);
        i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5);
        i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6);
        i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7);
        i7 += 8;
        const uint8x8_t vi8 = vld1_u8(i8);
        i8 += 8;

        const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
        const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

        const int32x4_t vacc_lo =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
        const int32x4_t vacc_hi =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
        acc += 4;

        k -= 8;
      }
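      /*
       * 1-7 channels remain: step the pointers back so a full 8-byte load
       * ends exactly at the last valid byte, then shift the valid bytes down
       * to the low lanes (vshl with a negative count shifts right),
       * zero-filling the rest.
       */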
      if (k != 0) {
        const size_t address_increment = k - 8;
        i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
        i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
        i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
        i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
        i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
        i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
        i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
        i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
        i8 = (const uint8_t*)((uintptr_t)i8 + address_increment);
        const int64x1_t vshift = vmov_n_s64(8 * address_increment);

        const uint8x8_t vi0 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
        const uint8x8_t vi1 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
        const uint8x8_t vi2 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
        const uint8x8_t vi3 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
        const uint8x8_t vi4 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
        const uint8x8_t vi5 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
        const uint8x8_t vi6 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));
        const uint8x8_t vi7 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift));
        const uint8x8_t vi8 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vshift));

        const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
        const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

        const int32x4_t vacc_lo =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
        const int32x4_t vacc_hi =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
      }
    }

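    /* Intermediate passes: accumulate 8 more rows per pass into the buffer. */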
    size_t m = ks;
    for (m -= 9; m > 8; m -= 8) {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;

      size_t k = kc;
      int32_t* acc = buffer;
      while (k >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0);
        i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1);
        i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2);
        i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3);
        i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4);
        i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5);
        i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6);
        i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7);
        i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(acc);
        int32x4_t vacc_hi = vld1q_s32(acc + 4);

        const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23);
        const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum)));
        vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
        acc += 4;

        k -= 8;
      }
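      /* 1-7 channels remain: same shifted tail load as in the first pass. */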
      if (k != 0) {
        const size_t address_increment = k - 8;
        i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
        i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
        i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
        i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
        i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
        i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
        i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
        i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
        const int64x1_t vshift = vmov_n_s64(8 * address_increment);

        const uint8x8_t vi0 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
        const uint8x8_t vi1 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
        const uint8x8_t vi2 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
        const uint8x8_t vi3 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
        const uint8x8_t vi4 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
        const uint8x8_t vi5 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
        const uint8x8_t vi6 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));
        const uint8x8_t vi7 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift));
        int32x4_t vacc_lo = vld1q_s32(acc);
        int32x4_t vacc_hi = vld1q_s32(acc + 4);

        const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23);
        const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum)));
        vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
      }
    }

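    /*
     * Final pass: accumulate the last 1-8 rows, then scale, round, add the
     * output zero point, and saturate to uint8.
     */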
    {
      const uint8_t* i0 = input[0];
      const uint8_t* i1 = input[1];
      const uint8_t* i2 = input[2];
      const uint8_t* i3 = input[3];
      const uint8_t* i4 = input[4];
      const uint8_t* i5 = input[5];
      const uint8_t* i6 = input[6];
      const uint8_t* i7 = input[7];
      input = (const uint8_t**)((uintptr_t)input + input_increment);
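      /*
       * Fewer than 8 rows may remain; point the unused rows at the zero
       * vector so they contribute nothing to the sum.
       */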
      if (m < 2) {
        i1 = zero;
      }
      if (m <= 2) {
        i2 = zero;
      }
      if (m < 4) {
        i3 = zero;
      }
      if (m <= 4) {
        i4 = zero;
      }
      if (m < 6) {
        i5 = zero;
      }
      if (m <= 6) {
        i6 = zero;
      }
      if (m != 8) {
        i7 = zero;
      }

      size_t k = kc;
      int32_t* acc = buffer;
      while (k >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0);
        i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1);
        i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2);
        i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3);
        i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4);
        i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5);
        i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6);
        i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7);
        i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(acc);
        acc += 4;
        int32x4_t vacc_hi = vld1q_s32(acc);
        acc += 4;

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

        float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
        float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

        vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
        vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

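        /*
         * Requantize: round the scaled sums back to int32, narrow with
         * saturation, and clamp to the output range (path differs by
         * architecture; see the constants above).
         */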
#if defined(__aarch64__)
        vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
        vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
        const int16x8_t vacc = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);
#else
        vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
        vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

        vacc_lo = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
        vacc_hi = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
        const int16x8_t vacc =
            vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
        uint8x8_t vout = vqmovun_s16(vacc);
#endif

        vst1_u8(output, vout);
        output += 8;

        k -= 8;
      }
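      /*
       * 1-7 channels remain: shifted tail load as before, then requantize
       * and store the last bytes piecewise.
       */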
      if (k != 0) {
        const size_t address_increment = k - 8;
        i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
        i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
        i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
        i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
        i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
        i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
        i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
        i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
        const int64x1_t vshift = vmov_n_s64(8 * address_increment);

        const uint8x8_t vi0 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
        const uint8x8_t vi1 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
        const uint8x8_t vi2 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
        const uint8x8_t vi3 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
        const uint8x8_t vi4 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
        const uint8x8_t vi5 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
        const uint8x8_t vi6 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));
        const uint8x8_t vi7 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift));
        int32x4_t vacc_lo = vld1q_s32(acc);
        acc += 4;
        int32x4_t vacc_hi = vld1q_s32(acc);

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

        float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
        float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

        vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
        vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

#if defined(__aarch64__)
        vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
        vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
        const int16x8_t vacc = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);
#else
        vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
        vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

        vacc_lo = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
        vacc_hi = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
        const int16x8_t vacc =
            vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
        uint8x8_t vout = vqmovun_s16(vacc);
#endif

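        /* Store the final k (1-7) bytes in 4-, 2-, and 1-byte pieces. */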
        if (k & 4) {
          vst1_lane_u32(
              __builtin_assume_aligned(output, 1),
              vreinterpret_u32_u8(vout),
              0);
          output += 4;
          vout = vext_u8(vout, vout, 4);
        }
        if (k & 2) {
          vst1_lane_u16(
              __builtin_assume_aligned(output, 1),
              vreinterpret_u16_u8(vout),
              0);
          output += 2;
          vout = vext_u8(vout, vout, 2);
        }
        if (k & 1) {
          vst1_lane_u8(output, vout, 0);
          output += 1;
        }
      }
    }
    output = (uint8_t*)((uintptr_t)output + output_increment);
  } while (--n != 0);
}