// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <arm_neon.h>

#include <qnnpack/q8avgpool.h>

/*
 * Quantized (uint8) average-pooling micro-kernel, NEON "unipass" variant:
 * handles pooling windows of up to 9 elements (ks <= 9) in a single pass,
 * for 8 or more channels (kc >= 8).
 *
 * n                   number of output pixels to compute (must be non-zero)
 * ks                  effective pooling-window size; rows beyond ks are
 *                     replaced with the `zero` buffer below
 * kc                  number of channels per pixel (>= 8)
 * input               per-output-pixel array of up to 9 input row pointers
 * zero                pointer to a zero-initialized buffer substituted for
 *                     the unused window rows when ks < 9
 * output              destination for the kc quantized averages per pixel
 * input_increment     byte stride added to `input` after each output pixel
 * output_increment    byte stride added to `output` after each output pixel
 * quantization_params precomputed requantization constants (bias, scale,
 *                     zero point / clamping values) in the `neon` layout
 */
void pytorch_q8avgpool_ukernel_up8x9__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    const uint8_t* zero,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[restrict static 1]) {
  assert(n != 0);
  assert(ks <= 9);
  assert(kc >= 8);

  /* Requantization constants, broadcast across all lanes. */
  const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias);
  const float32x4_t vscale =
      vdupq_n_f32(quantization_params->neon.scale);
#if defined(__aarch64__)
  /* AArch64 path: round with FCVTNS, then clamp in the integer domain. */
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  const uint8x8_t voutput_min =
      vld1_dup_u8(&quantization_params->neon.output_min);
  const uint8x8_t voutput_max =
      vld1_dup_u8(&quantization_params->neon.output_max);
#else
  /*
   * AArch32 path: no round-to-nearest float->int conversion instruction, so
   * clamp in the float domain, then convert via the "magic number" trick
   * (add vfmagic so the integer of interest lands in the mantissa,
   * reinterpret the bits, subtract vimagic).
   */
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
#endif

  do {
    const uint8_t* i0 = input[0];
    const uint8_t* i1 = input[1];
    const uint8_t* i2 = input[2];
    const uint8_t* i3 = input[3];
    const uint8_t* i4 = input[4];
    const uint8_t* i5 = input[5];
    const uint8_t* i6 = input[6];
    const uint8_t* i7 = input[7];
    const uint8_t* i8 = input[8];
    input = (const uint8_t**)((uintptr_t)input + input_increment);
    /* Redirect the rows past the effective window size to the zero buffer,
     * so they contribute nothing to the sum. */
    if (ks < 2) {
      i1 = zero;
    }
    if (ks <= 2) {
      i2 = zero;
    }
    if (ks < 4) {
      i3 = zero;
    }
    if (ks <= 4) {
      i4 = zero;
    }
    if (ks < 6) {
      i5 = zero;
    }
    if (ks <= 6) {
      i6 = zero;
    }
    if (ks < 8) {
      i7 = zero;
    }
    if (ks <= 8) {
      i8 = zero;
    }

    /* Main loop: 8 channels per iteration. */
    size_t k = kc;
    while (k >= 8) {
      const uint8x8_t vi0 = vld1_u8(i0);
      i0 += 8;
      const uint8x8_t vi1 = vld1_u8(i1);
      i1 += 8;
      const uint8x8_t vi2 = vld1_u8(i2);
      i2 += 8;
      const uint8x8_t vi3 = vld1_u8(i3);
      i3 += 8;
      const uint8x8_t vi4 = vld1_u8(i4);
      i4 += 8;
      const uint8x8_t vi5 = vld1_u8(i5);
      i5 += 8;
      const uint8x8_t vi6 = vld1_u8(i6);
      i6 += 8;
      const uint8x8_t vi7 = vld1_u8(i7);
      i7 += 8;
      const uint8x8_t vi8 = vld1_u8(i8);
      i8 += 8;

      /* Sum the 9 rows in u16; max value 9*255 = 2295 fits comfortably.
       * The additions are tree-shaped to shorten the dependency chain. */
      const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
      const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
      const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
      const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

      /* Widen to s32 and add the precomputed bias; the u16 sums are < 2^15,
       * so reinterpreting them as s16 is lossless. */
      int32x4_t vacc_lo =
          vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
      int32x4_t vacc_hi =
          vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

      /* Scale in float (the averaging divisor is folded into vscale). */
      float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
      float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

      vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
      vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

#if defined(__aarch64__)
      /* Round-to-nearest, add the output zero point with saturation,
       * narrow to u8, and clamp to [output_min, output_max]. */
      vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
      vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
      const int16x8_t vacc = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
      uint8x8_t vout = vqmovun_s16(vacc);
      vout = vmax_u8(vout, voutput_min);
      vout = vmin_u8(vout, voutput_max);
#else
      /* Clamp in float, then magic-number conversion to s32 (the output
       * zero point is folded into vfmagic/vimagic). */
      vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
      vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

      vacc_lo = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
      vacc_hi = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
      const int16x8_t vacc =
          vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
      uint8x8_t vout = vqmovun_s16(vacc);
#endif

      vst1_u8(output, vout);
      output += 8;

      k -= 8;
    }
    /* Remainder: 1-7 trailing channels. Since kc >= 8, backing the pointers
     * up by (8 - k) bytes keeps the 8-byte loads in bounds; the stale low
     * bytes are then shifted out below. */
    if (k != 0) {
      const size_t address_increment = k - 8;
      i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
      i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
      i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
      i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
      i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
      i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
      i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
      i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
      i8 = (const uint8_t*)((uintptr_t)i8 + address_increment);
      /* VSHL with a negative count shifts right: this discards the (8 - k)
       * stale low bytes and zero-fills the top lanes, so lanes k..7 sum to
       * zero and only lanes 0..k-1 carry real data. */
      const int64x1_t vshift = vmov_n_s64(8 * address_increment);

      const uint8x8_t vi0 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift));
      const uint8x8_t vi1 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift));
      const uint8x8_t vi2 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift));
      const uint8x8_t vi3 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift));
      const uint8x8_t vi4 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift));
      const uint8x8_t vi5 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift));
      const uint8x8_t vi6 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift));
      const uint8x8_t vi7 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift));
      const uint8x8_t vi8 = vreinterpret_u8_u64(
          vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vshift));

      /* Same sum/requantize sequence as the main loop. */
      const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8);
      const uint16x8_t vsum23 = vaddl_u8(vi2, vi3);
      const uint16x8_t vsum45 = vaddl_u8(vi4, vi5);
      const uint16x8_t vsum67 = vaddl_u8(vi6, vi7);

      const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45);
      const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67);
      const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678);

      int32x4_t vacc_lo =
          vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum)));
      int32x4_t vacc_hi =
          vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum)));

      float32x4_t vacc_lo_f = vcvtq_f32_s32(vacc_lo);
      float32x4_t vacc_hi_f = vcvtq_f32_s32(vacc_hi);

      vacc_lo_f = vmulq_f32(vacc_lo_f, vscale);
      vacc_hi_f = vmulq_f32(vacc_hi_f, vscale);

#if defined(__aarch64__)
      vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
      vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
      const int16x8_t vacc = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
      uint8x8_t vout = vqmovun_s16(vacc);
      vout = vmax_u8(vout, voutput_min);
      vout = vmin_u8(vout, voutput_max);
#else
      vacc_lo_f = vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
      vacc_hi_f = vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);

      vacc_lo = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f, vfmagic)), vimagic);
      vacc_hi = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f, vfmagic)), vimagic);
      const int16x8_t vacc =
          vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
      uint8x8_t vout = vqmovun_s16(vacc);
#endif

      /* Store the low k bytes of vout: write 4, 2, then 1 byte, rotating
       * the consumed bytes out with vext after each partial store. */
      if (k & 4) {
        vst1_lane_u32(
            __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
        output += 4;
        vout = vext_u8(vout, vout, 4);
      }
      if (k & 2) {
        vst1_lane_u16(
            __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
        output += 2;
        vout = vext_u8(vout, vout, 2);
      }
      if (k & 1) {
        vst1_lane_u8(output, vout, 0);
        output += 1;
      }
    }
    output = (uint8_t*)((uintptr_t)output + output_increment);
  } while (--n != 0);
}
243