1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10
11 #include <arm_neon.h>
12
13 #include <qnnpack/q8avgpool.h>
14
/*
 * Lane-wise sum of nine 8-byte rows, widened to 16 bits per lane.
 * Nine uint8 values add up to at most 9 * 255 = 2295, so a 16-bit lane
 * cannot overflow.
 */
static inline uint16x8_t q8avgpool_up8x9_sum_rows__neon(
    const uint8x8_t r0,
    const uint8x8_t r1,
    const uint8x8_t r2,
    const uint8x8_t r3,
    const uint8x8_t r4,
    const uint8x8_t r5,
    const uint8x8_t r6,
    const uint8x8_t r7,
    const uint8x8_t r8) {
  const uint16x8_t sum018 = vaddw_u8(vaddl_u8(r0, r1), r8);
  const uint16x8_t sum23 = vaddl_u8(r2, r3);
  const uint16x8_t sum45 = vaddl_u8(r4, r5);
  const uint16x8_t sum67 = vaddl_u8(r6, r7);

  const uint16x8_t sum2345 = vaddq_u16(sum23, sum45);
  const uint16x8_t sum01678 = vaddq_u16(sum018, sum67);
  return vaddq_u16(sum2345, sum01678);
}

/*
 * Loads the 8 bytes at (row + back_off) and applies `shift` to them as a
 * single 64-bit lane. The caller passes a wrapped (effectively negative)
 * back_off and a matching negative shift count, which makes vshl_u64 shift
 * right so the bytes that were already consumed are discarded.
 */
static inline uint8x8_t q8avgpool_up8x9_load_tail__neon(
    const uint8_t* row,
    const size_t back_off,
    const int64x1_t shift) {
  const uint8_t* const src = (const uint8_t*)((uintptr_t)row + back_off);
  return vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(src)), shift));
}

/*
 * Quantized uint8 average-pooling micro-kernel for pooling windows of at
 * most 9 elements, processing 8 channels per step.
 *
 * For each of the n output pixels the kernel reads nine row pointers from
 * `input`, replaces the pointers at index >= ks with `zero`, and for every
 * channel sums the nine bytes, adds the bias, multiplies by the scale in
 * float, converts back to an integer, applies the output zero point and
 * clamping, and stores one byte.
 *
 *   n                 number of output pixels; must be non-zero
 *   ks                number of valid row pointers per pixel; must be <= 9
 *   kc                number of channels; must be >= 8
 *   input             array of row pointers; advanced by input_increment
 *                     BYTES after each pixel
 *   zero              row substituted for the unused pointer slots; it is
 *                     read for kc bytes just like a real row
 *   output            destination; kc bytes are written per pixel, then the
 *                     pointer is advanced by output_increment BYTES
 *   quantization_params
 *                     bias, scale and output clamping constants
 *
 * NOTE(review): all nine pointer slots input[0..8] are loaded even when
 * ks < 9; presumably the caller sizes the pointer array accordingly --
 * confirm against the operator setup code.
 */
void pytorch_q8avgpool_ukernel_up8x9__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    const uint8_t* zero,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[restrict static 1]) {
  assert(n != 0);
  assert(ks <= 9);
  assert(kc >= 8);

  /* Broadcast the per-call constants once, outside of all loops. */
  const int32x4_t bias_v = vld1q_dup_s32(&quantization_params->neon.bias);
  const float32x4_t scale_v = vdupq_n_f32(quantization_params->neon.scale);
#if defined(__aarch64__)
  const int16x8_t zero_point_v =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  const uint8x8_t out_min_v =
      vld1_dup_u8(&quantization_params->neon.output_min);
  const uint8x8_t out_max_v =
      vld1_dup_u8(&quantization_params->neon.output_max);
#else
  const float32x4_t fmin_v = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t fmax_v = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t fmagic_v = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t imagic_v = vdupq_n_s32(quantization_params->neon.vimagic);
#endif

  do {
    /* Fetch the nine row pointers for this pixel, then step to the next set. */
    const uint8_t* row0 = input[0];
    const uint8_t* row1 = input[1];
    const uint8_t* row2 = input[2];
    const uint8_t* row3 = input[3];
    const uint8_t* row4 = input[4];
    const uint8_t* row5 = input[5];
    const uint8_t* row6 = input[6];
    const uint8_t* row7 = input[7];
    const uint8_t* row8 = input[8];
    input = (const uint8_t**)((uintptr_t)input + input_increment);

    /* Row j takes part only when j < ks; every other slot reads `zero`. */
    if (ks <= 1) {
      row1 = zero;
    }
    if (ks <= 2) {
      row2 = zero;
    }
    if (ks <= 3) {
      row3 = zero;
    }
    if (ks <= 4) {
      row4 = zero;
    }
    if (ks <= 5) {
      row5 = zero;
    }
    if (ks <= 6) {
      row6 = zero;
    }
    if (ks <= 7) {
      row7 = zero;
    }
    if (ks <= 8) {
      row8 = zero;
    }

    /* Full groups of 8 channels. */
    size_t remaining = kc;
    for (; remaining >= 8; remaining -= 8) {
      const uint8x8_t v0 = vld1_u8(row0);
      row0 += 8;
      const uint8x8_t v1 = vld1_u8(row1);
      row1 += 8;
      const uint8x8_t v2 = vld1_u8(row2);
      row2 += 8;
      const uint8x8_t v3 = vld1_u8(row3);
      row3 += 8;
      const uint8x8_t v4 = vld1_u8(row4);
      row4 += 8;
      const uint8x8_t v5 = vld1_u8(row5);
      row5 += 8;
      const uint8x8_t v6 = vld1_u8(row6);
      row6 += 8;
      const uint8x8_t v7 = vld1_u8(row7);
      row7 += 8;
      const uint8x8_t v8 = vld1_u8(row8);
      row8 += 8;

      const uint16x8_t sum =
          q8avgpool_up8x9_sum_rows__neon(v0, v1, v2, v3, v4, v5, v6, v7, v8);

      /* Widen to 32 bits while adding the bias, then scale in float. */
      int32x4_t acc_lo =
          vaddw_s16(bias_v, vreinterpret_s16_u16(vget_low_u16(sum)));
      int32x4_t acc_hi =
          vaddw_s16(bias_v, vreinterpret_s16_u16(vget_high_u16(sum)));

      float32x4_t scaled_lo = vmulq_f32(vcvtq_f32_s32(acc_lo), scale_v);
      float32x4_t scaled_hi = vmulq_f32(vcvtq_f32_s32(acc_hi), scale_v);

#if defined(__aarch64__)
      /* Convert to int32, narrow with saturation, add zero point, clamp. */
      acc_lo = vcvtnq_s32_f32(scaled_lo);
      acc_hi = vcvtnq_s32_f32(scaled_hi);
      const int16x8_t acc16 = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(acc_lo), acc_hi), zero_point_v);
      const uint8x8_t out_bytes =
          vmin_u8(vmax_u8(vqmovun_s16(acc16), out_min_v), out_max_v);
#else
      /*
       * Clamp in float, add fmagic, reinterpret the bits as int32 and
       * subtract imagic (presumably the magic-bias float-to-int conversion
       * with the output zero point folded in -- confirm against the code
       * that fills in the parameters).
       */
      scaled_lo = vminq_f32(vmaxq_f32(scaled_lo, fmin_v), fmax_v);
      scaled_hi = vminq_f32(vmaxq_f32(scaled_hi, fmin_v), fmax_v);

      acc_lo = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(scaled_lo, fmagic_v)), imagic_v);
      acc_hi = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(scaled_hi, fmagic_v)), imagic_v);
      const int16x8_t acc16 =
          vcombine_s16(vqmovn_s32(acc_lo), vqmovn_s32(acc_hi));
      const uint8x8_t out_bytes = vqmovun_s16(acc16);
#endif

      vst1_u8(output, out_bytes);
      output += 8;
    }

    /* Trailing 1..7 channels. */
    if (remaining != 0) {
      /*
       * remaining - 8 wraps around in size_t and acts as a negative byte
       * offset: each 8-byte load ends exactly at the last channel and starts
       * inside data that was already processed (kc >= 8 keeps it inside the
       * row). The shift count of 8 * back_off bits then moves the unprocessed
       * bytes down to the low lanes.
       */
      const size_t back_off = remaining - 8;
      const int64x1_t tail_shift = vmov_n_s64(8 * back_off);

      const uint8x8_t v0 =
          q8avgpool_up8x9_load_tail__neon(row0, back_off, tail_shift);
      const uint8x8_t v1 =
          q8avgpool_up8x9_load_tail__neon(row1, back_off, tail_shift);
      const uint8x8_t v2 =
          q8avgpool_up8x9_load_tail__neon(row2, back_off, tail_shift);
      const uint8x8_t v3 =
          q8avgpool_up8x9_load_tail__neon(row3, back_off, tail_shift);
      const uint8x8_t v4 =
          q8avgpool_up8x9_load_tail__neon(row4, back_off, tail_shift);
      const uint8x8_t v5 =
          q8avgpool_up8x9_load_tail__neon(row5, back_off, tail_shift);
      const uint8x8_t v6 =
          q8avgpool_up8x9_load_tail__neon(row6, back_off, tail_shift);
      const uint8x8_t v7 =
          q8avgpool_up8x9_load_tail__neon(row7, back_off, tail_shift);
      const uint8x8_t v8 =
          q8avgpool_up8x9_load_tail__neon(row8, back_off, tail_shift);

      const uint16x8_t sum =
          q8avgpool_up8x9_sum_rows__neon(v0, v1, v2, v3, v4, v5, v6, v7, v8);

      int32x4_t acc_lo =
          vaddw_s16(bias_v, vreinterpret_s16_u16(vget_low_u16(sum)));
      int32x4_t acc_hi =
          vaddw_s16(bias_v, vreinterpret_s16_u16(vget_high_u16(sum)));

      float32x4_t scaled_lo = vmulq_f32(vcvtq_f32_s32(acc_lo), scale_v);
      float32x4_t scaled_hi = vmulq_f32(vcvtq_f32_s32(acc_hi), scale_v);

#if defined(__aarch64__)
      acc_lo = vcvtnq_s32_f32(scaled_lo);
      acc_hi = vcvtnq_s32_f32(scaled_hi);
      const int16x8_t acc16 = vqaddq_s16(
          vqmovn_high_s32(vqmovn_s32(acc_lo), acc_hi), zero_point_v);
      uint8x8_t out_bytes =
          vmin_u8(vmax_u8(vqmovun_s16(acc16), out_min_v), out_max_v);
#else
      scaled_lo = vminq_f32(vmaxq_f32(scaled_lo, fmin_v), fmax_v);
      scaled_hi = vminq_f32(vmaxq_f32(scaled_hi, fmin_v), fmax_v);

      acc_lo = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(scaled_lo, fmagic_v)), imagic_v);
      acc_hi = vsubq_s32(
          vreinterpretq_s32_f32(vaddq_f32(scaled_hi, fmagic_v)), imagic_v);
      const int16x8_t acc16 =
          vcombine_s16(vqmovn_s32(acc_lo), vqmovn_s32(acc_hi));
      uint8x8_t out_bytes = vqmovun_s16(acc16);
#endif

      /*
       * Store 4, 2 and 1 bytes according to the bits of `remaining`,
       * rotating the bytes already written out of the low lanes.
       */
      if ((remaining & 4) != 0) {
        vst1_lane_u32(
            __builtin_assume_aligned(output, 1),
            vreinterpret_u32_u8(out_bytes),
            0);
        output += 4;
        out_bytes = vext_u8(out_bytes, out_bytes, 4);
      }
      if ((remaining & 2) != 0) {
        vst1_lane_u16(
            __builtin_assume_aligned(output, 1),
            vreinterpret_u16_u8(out_bytes),
            0);
        output += 2;
        out_bytes = vext_u8(out_bytes, out_bytes, 2);
      }
      if ((remaining & 1) != 0) {
        vst1_lane_u8(output, out_bytes, 0);
        output += 1;
      }
    }

    output = (uint8_t*)((uintptr_t)output + output_increment);
    n -= 1;
  } while (n != 0);
}
243