/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <arm_neon.h>

#include <qnnpack/q8avgpool.h>

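/*
 * Multi-pass average-pooling micro-kernel for 8 or more channels: a 9-element
 * first pass initializes the int32 accumulator buffer with the bias, each
 * intermediate pass adds 8 more pooling elements, and the final pass adds the
 * remaining elements, requantizes, and stores the uint8 output.
 */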
void pytorch_q8avgpool_ukernel_mp8x9p8q__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[restrict static 1]) {
  assert(n != 0);
  assert(ks > 9);
  assert(kc >= 8);

  const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias);
  const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale);
#if defined(__aarch64__)
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  const uint8x8_t voutput_min =
      vld1_dup_u8(&quantization_params->neon.output_min);
  const uint8x8_t voutput_max =
      vld1_dup_u8(&quantization_params->neon.output_max);
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
#endif

  do {
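    /* First pass: sum the first 9 pooling elements and initialize the
     * accumulator buffer with bias + partial sums. */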
    {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;
      const uint8_t* i8 = *input++;

      size_t k = kc;
      int32_t* acc = buffer;
      while (k >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0);
        i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1);
        i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2);
        i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3);
        i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4);
        i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5);
        i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6);
        i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7);
        i7 += 8;
        const uint8x8_t vi8 = vld1_u8(i8);
        i8 += 8;

        const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
        const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

        const int32x4_t vacc_lo =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
        const int32x4_t vacc_hi =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
        acc += 4;

        k -= 8;
      }
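      /* Remainder channels (1-7): back the pointers up by (8 - k) so an
       * 8-byte load ends at the last channel, then shift right to drop the
       * already-processed bytes and zero-fill the unused lanes. */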
      if (k != 0) {
        const size_t address_increment = k - 8;
        i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
        i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
        i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
        i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
        i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
        i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
        i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
        i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
        i8 = (const uint8_t*)((uintptr_t)i8 + address_increment);
        const int64x1_t vshift = vmov_n_s64(8 * address_increment);

        const uint8x8_t vi0 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
        const uint8x8_t vi1 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
        const uint8x8_t vi2 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
        const uint8x8_t vi3 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
        const uint8x8_t vi4 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
        const uint8x8_t vi5 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
        const uint8x8_t vi6 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));
        const uint8x8_t vi7 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift));
        const uint8x8_t vi8 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vshift));

        const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
        const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

        const int32x4_t vacc_lo =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
        const int32x4_t vacc_hi =
            vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
      }
    }

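    /* Intermediate passes: each pass adds 8 more pooling elements into the
     * accumulator buffer. */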
    size_t m = ks;
    for (m -= 9; m > 8; m -= 8) {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;

      size_t k = kc;
      int32_t* acc = buffer;
      while (k >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0);
        i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1);
        i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2);
        i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3);
        i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4);
        i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5);
        i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6);
        i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7);
        i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(acc);
        int32x4_t vacc_hi = vld1q_s32(acc + 4);

        const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23);
        const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum)));
        vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
        acc += 4;

        k -= 8;
      }
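      /* Remainder channels: same back-up-and-shift load as in the first pass. */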
      if (k != 0) {
        const size_t address_increment = k - 8;
        i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
        i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
        i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
        i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
        i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
        i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
        i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
        i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
        const int64x1_t vshift = vmov_n_s64(8 * address_increment);

        const uint8x8_t vi0 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
        const uint8x8_t vi1 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
        const uint8x8_t vi2 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
        const uint8x8_t vi3 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
        const uint8x8_t vi4 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
        const uint8x8_t vi5 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
        const uint8x8_t vi6 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));
        const uint8x8_t vi7 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift));
        int32x4_t vacc_lo = vld1q_s32(acc);
        int32x4_t vacc_hi = vld1q_s32(acc + 4);

        const uint16x8_t vsum01 = vaddl_u8(vi0, vi1);
        const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
        const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
        const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

        const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23);
        const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67);
        const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum)));
        vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum)));

        vst1q_s32(acc, vacc_lo);
        acc += 4;
        vst1q_s32(acc, vacc_hi);
      }
    }

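    /* Final pass: add the last 1-8 pooling elements, requantize the
     * accumulators, and store the output row. */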
    {
      const uint8_t* i0 = input[0];
      const uint8_t* i1 = input[1];
      const uint8_t* i2 = input[2];
      const uint8_t* i3 = input[3];
      const uint8_t* i4 = input[4];
      const uint8_t* i5 = input[5];
      const uint8_t* i6 = input[6];
      const uint8_t* i7 = input[7];
      input = (const uint8_t**)((uintptr_t)input + input_increment);
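      /* Only m (1-8) pooling elements remain; point the unused inputs at the
       * caller-provided zero buffer. */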
      if (m < 2) {
        i1 = zero;
      }
      if (m <= 2) {
        i2 = zero;
      }
      if (m < 4) {
        i3 = zero;
      }
      if (m <= 4) {
        i4 = zero;
      }
      if (m < 6) {
        i5 = zero;
      }
      if (m <= 6) {
        i6 = zero;
      }
      if (m != 8) {
        i7 = zero;
      }

      size_t k = kc;
      int32_t* acc = buffer;
      while (k >= 8) {
        const uint8x8_t vi0 = vld1_u8(i0);
        i0 += 8;
        const uint8x8_t vi1 = vld1_u8(i1);
        i1 += 8;
        const uint8x8_t vi2 = vld1_u8(i2);
        i2 += 8;
        const uint8x8_t vi3 = vld1_u8(i3);
        i3 += 8;
        const uint8x8_t vi4 = vld1_u8(i4);
        i4 += 8;
        const uint8x8_t vi5 = vld1_u8(i5);
        i5 += 8;
        const uint8x8_t vi6 = vld1_u8(i6);
        i6 += 8;
        const uint8x8_t vi7 = vld1_u8(i7);
        i7 += 8;
        int32x4_t vacc_lo = vld1q_s32(acc);
        acc += 4;
        int32x4_t vacc_hi = vld1q_s32(acc);
        acc += 4;

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

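        /* Requantize: scale the int32 sums in floating point, then round and
         * clamp to the quantized output range. */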
        float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
        float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

        vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
        vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

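        /* AArch64: round-to-nearest convert, saturating narrow, add the output
         * zero point, and clamp in the integer domain. ARMv7: clamp in the
         * float domain, then round via the magic-number add/subtract, whose
         * vfmagic/vimagic constants also account for the output zero point. */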
#if defined(__aarch64__)
        vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
        vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
        const int16x8_t vacc = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);
#else
        vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
        vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

        vacc_lo = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
        vacc_hi = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
        const int16x8_t vacc =
            vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
        uint8x8_t vout = vqmovun_s16(vacc);
#endif

        vst1_u8(output, vout);
        output += 8;

        k -= 8;
      }
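      /* Remainder channels: reload with the back-up-and-shift trick, then
       * requantize exactly as above. */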
      if (k != 0) {
        const size_t address_increment = k - 8;
        i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
        i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
        i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
        i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
        i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
        i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
        i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
        i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
        const int64x1_t vshift = vmov_n_s64(8 * address_increment);

        const uint8x8_t vi0 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
        const uint8x8_t vi1 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
        const uint8x8_t vi2 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
        const uint8x8_t vi3 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
        const uint8x8_t vi4 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
        const uint8x8_t vi5 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
        const uint8x8_t vi6 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));
        const uint8x8_t vi7 = vreinterpret_u8_u64(
            vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift));
        int32x4_t vacc_lo = vld1q_s32(acc);
        acc += 4;
        int32x4_t vacc_hi = vld1q_s32(acc);

        const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1));
        const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3));
        const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5));
        const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7));

        const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23);
        const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67);
        const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567);

        vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum));
        vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum));

        float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
        float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

        vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
        vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

#if defined(__aarch64__)
        vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
        vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
        const int16x8_t vacc = vqaddq_s16(
            vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
        uint8x8_t vout = vqmovun_s16(vacc);
        vout = vmax_u8(vout, voutput_min);
        vout = vmin_u8(vout, voutput_max);
#else
        vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
        vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

        vacc_lo = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
        vacc_hi = vsubq_s32(
            vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
        const int16x8_t vacc =
            vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
        uint8x8_t vout = vqmovun_s16(vacc);
#endif

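        /* Store the remaining 1-7 output bytes in 4-, 2-, and 1-byte chunks. */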
        if (k & 4) {
          vst1_lane_u32(
              __builtin_assume_aligned(output, 1),
              vreinterpret_u32_u8(vout),
              0);
          output += 4;
          vout = vext_u8(vout, vout, 4);
        }
        if (k & 2) {
          vst1_lane_u16(
              __builtin_assume_aligned(output, 1),
              vreinterpret_u16_u8(vout),
              0);
          output += 2;
          vout = vext_u8(vout, vout, 2);
        }
        if (k & 1) {
          vst1_lane_u8(output, vout, 0);
          output += 1;
        }
      }
    }
    output = (uint8_t*)((uintptr_t)output + output_increment);
  } while (--n != 0);
}