1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <arm_neon.h>
10
11 #include <qnnpack/q8dwconv.h>
12 #include <requantization/runtime-neon.h>
13
pytorch_q8dwconv_ukernel_up8x9__neon(size_t channels,size_t output_width,const uint8_t ** input,const void * weights,uint8_t * output,size_t input_stride,size_t output_increment,const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static1])14 void pytorch_q8dwconv_ukernel_up8x9__neon(
15 size_t channels,
16 size_t output_width,
17 const uint8_t** input,
18 const void* weights,
19 uint8_t* output,
20 size_t input_stride,
21 size_t output_increment,
22 const union pytorch_qnnp_conv_quantization_params
23 quantization_params[restrict static 1]) {
24 const uint8x8_t va_zero_point =
25 vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
26 const uint8x8_t vkernel_zero_point =
27 vdup_n_u8(quantization_params->neon.kernel_zero_points[0]);
28 const float32x4_t requantization_scale_v =
29 vdupq_n_f32(quantization_params->neon.requantization_scales[0]);
30 #ifdef __aarch64__
31 const int16x8_t voutput_zero_point =
32 vld1q_dup_s16(&quantization_params->neon.output_zero_point);
33 const uint8x8_t voutput_min =
34 vld1_dup_u8(&quantization_params->neon.output_min);
35 const uint8x8_t voutput_max =
36 vld1_dup_u8(&quantization_params->neon.output_max);
37 #else
38 const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
39 const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
40 const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
41 const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
42 #endif
43
44 #ifdef __aarch64__
45 /* The larger number of registers on AArch64 makes it possible to process a
46 * few pixels at a time */
47 if (input_stride == 3 * sizeof(void*)) {
48 for (; output_width >= 3; output_width -= 3) {
49 const uint8_t* i00 = input[0];
50 const uint8_t* i10 = input[1];
51 const uint8_t* i20 = input[2];
52 const uint8_t* i01 = input[3];
53 const uint8_t* i11 = input[4];
54 const uint8_t* i21 = input[5];
55 const uint8_t* i02 = input[6];
56 const uint8_t* i12 = input[7];
57 const uint8_t* i22 = input[8];
58 const uint8_t* i03 = input[9];
59 const uint8_t* i13 = input[10];
60 const uint8_t* i23 = input[11];
61 const uint8_t* i04 = input[12];
62 const uint8_t* i14 = input[13];
63 const uint8_t* i24 = input[14];
64
65 uint8_t* output0 = output;
66 uint8_t* output1 = output0 + channels + output_increment;
67 uint8_t* output2 = output1 + channels + output_increment;
68
69 input += 9;
70
71 size_t c = channels;
72 const void* w = weights;
73 for (; c >= 8; c -= 8) {
74 int32x4_t vacc0_lo = vld1q_s32(w);
75 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
76 int32x4_t vacc0_hi = vld1q_s32(w);
77 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
78 int32x4_t vacc1_lo = vacc0_lo;
79 int32x4_t vacc2_lo = vacc0_lo;
80 int32x4_t vacc1_hi = vacc0_hi;
81 int32x4_t vacc2_hi = vacc0_hi;
82
83 const uint8x8_t vk00 = vld1_u8(w);
84 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
85 const uint8x8_t vi00 = vld1_u8(i00);
86 i00 += 8;
87 const uint8x8_t vi01 = vld1_u8(i01);
88 i01 += 8;
89 const uint8x8_t vi02 = vld1_u8(i02);
90 i02 += 8;
91 const int16x8_t vxk00 =
92 vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
93 const int16x8_t vxi00 =
94 vreinterpretq_s16_u16(sub_zero_point(vi00, va_zero_point));
95 const int16x8_t vxi01 =
96 vreinterpretq_s16_u16(sub_zero_point(vi01, va_zero_point));
97 const int16x8_t vxi02 =
98 vreinterpretq_s16_u16(sub_zero_point(vi02, va_zero_point));
99 vacc0_lo =
100 vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
101 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
102 vacc1_lo =
103 vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
104 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
105 vacc2_lo =
106 vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
107 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
108
109 const uint8x8_t vk10 = vld1_u8(w);
110 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
111 const uint8x8_t vi10 = vld1_u8(i10);
112 i10 += 8;
113 const uint8x8_t vi11 = vld1_u8(i11);
114 i11 += 8;
115 const uint8x8_t vi12 = vld1_u8(i12);
116 i12 += 8;
117 const int16x8_t vxk10 =
118 vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
119 const int16x8_t vxi10 =
120 vreinterpretq_s16_u16(sub_zero_point(vi10, va_zero_point));
121 const int16x8_t vxi11 =
122 vreinterpretq_s16_u16(sub_zero_point(vi11, va_zero_point));
123 const int16x8_t vxi12 =
124 vreinterpretq_s16_u16(sub_zero_point(vi12, va_zero_point));
125 vacc0_lo =
126 vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
127 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
128 vacc1_lo =
129 vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
130 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
131 vacc2_lo =
132 vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
133 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
134
135 const uint8x8_t vk20 = vld1_u8(w);
136 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
137 const uint8x8_t vi20 = vld1_u8(i20);
138 i20 += 8;
139 const uint8x8_t vi21 = vld1_u8(i21);
140 i21 += 8;
141 const uint8x8_t vi22 = vld1_u8(i22);
142 i22 += 8;
143 const int16x8_t vxk20 =
144 vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
145 const int16x8_t vxi20 =
146 vreinterpretq_s16_u16(sub_zero_point(vi20, va_zero_point));
147 const int16x8_t vxi21 =
148 vreinterpretq_s16_u16(sub_zero_point(vi21, va_zero_point));
149 const int16x8_t vxi22 =
150 vreinterpretq_s16_u16(sub_zero_point(vi22, va_zero_point));
151 vacc0_lo =
152 vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
153 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
154 vacc1_lo =
155 vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
156 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
157 vacc2_lo =
158 vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
159 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
160
161 const uint8x8_t vk01 = vld1_u8(w);
162 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
163 const uint8x8_t vi03 = vld1_u8(i03);
164 i03 += 8;
165 const int16x8_t vxk01 =
166 vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
167 const int16x8_t vxi03 =
168 vreinterpretq_s16_u16(sub_zero_point(vi03, va_zero_point));
169 vacc0_lo =
170 vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
171 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
172 vacc1_lo =
173 vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
174 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
175 vacc2_lo =
176 vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
177 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
178
179 const uint8x8_t vk11 = vld1_u8(w);
180 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
181 const uint8x8_t vi13 = vld1_u8(i13);
182 i13 += 8;
183 const int16x8_t vxk11 =
184 vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
185 const int16x8_t vxi13 =
186 vreinterpretq_s16_u16(sub_zero_point(vi13, va_zero_point));
187 vacc0_lo =
188 vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
189 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
190 vacc1_lo =
191 vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
192 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
193 vacc2_lo =
194 vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
195 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
196
197 const uint8x8_t vk21 = vld1_u8(w);
198 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
199 const uint8x8_t vi23 = vld1_u8(i23);
200 i23 += 8;
201 const int16x8_t vxk21 =
202 vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
203 const int16x8_t vxi23 =
204 vreinterpretq_s16_u16(sub_zero_point(vi23, va_zero_point));
205 vacc0_lo =
206 vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
207 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
208 vacc1_lo =
209 vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
210 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
211 vacc2_lo =
212 vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
213 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
214
215 const uint8x8_t vk02 = vld1_u8(w);
216 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
217 const uint8x8_t vi04 = vld1_u8(i04);
218 i04 += 8;
219 const int16x8_t vxk02 =
220 vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
221 const int16x8_t vxi04 =
222 vreinterpretq_s16_u16(sub_zero_point(vi04, va_zero_point));
223 vacc0_lo =
224 vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
225 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
226 vacc1_lo =
227 vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
228 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
229 vacc2_lo =
230 vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
231 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
232
233 const uint8x8_t vk12 = vld1_u8(w);
234 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
235 const uint8x8_t vi14 = vld1_u8(i14);
236 i14 += 8;
237 const int16x8_t vxk12 =
238 vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
239 const int16x8_t vxi14 =
240 vreinterpretq_s16_u16(sub_zero_point(vi14, va_zero_point));
241 vacc0_lo =
242 vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
243 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
244 vacc1_lo =
245 vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
246 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
247 vacc2_lo =
248 vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
249 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
250
251 const uint8x8_t vk22 = vld1_u8(w);
252 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
253 const uint8x8_t vi24 = vld1_u8(i24);
254 i24 += 8;
255 const int16x8_t vxk22 =
256 vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
257 const int16x8_t vxi24 =
258 vreinterpretq_s16_u16(sub_zero_point(vi24, va_zero_point));
259 vacc0_lo =
260 vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
261 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
262 vacc1_lo =
263 vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
264 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
265 vacc2_lo =
266 vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
267 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
268
269 vacc0_lo = vcvtnq_s32_f32(
270 vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v));
271 vacc0_hi = vcvtnq_s32_f32(
272 vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v));
273 vacc1_lo = vcvtnq_s32_f32(
274 vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v));
275 vacc1_hi = vcvtnq_s32_f32(
276 vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v));
277 vacc2_lo = vcvtnq_s32_f32(
278 vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v));
279 vacc2_hi = vcvtnq_s32_f32(
280 vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v));
281
282 const int16x8_t vacc0 = vqaddq_s16(
283 vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi),
284 voutput_zero_point);
285 const int16x8_t vacc1 = vqaddq_s16(
286 vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi),
287 voutput_zero_point);
288 const int16x8_t vacc2 = vqaddq_s16(
289 vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi),
290 voutput_zero_point);
291 uint8x8_t vout0 = vqmovun_s16(vacc0);
292 uint8x8_t vout1 = vqmovun_s16(vacc1);
293 uint8x8_t vout2 = vqmovun_s16(vacc2);
294 vout0 = vmax_u8(vout0, voutput_min);
295 vout1 = vmax_u8(vout1, voutput_min);
296 vout2 = vmax_u8(vout2, voutput_min);
297 vout0 = vmin_u8(vout0, voutput_max);
298 vout1 = vmin_u8(vout1, voutput_max);
299 vout2 = vmin_u8(vout2, voutput_max);
300
301 vst1_u8(output0, vout0);
302 output0 += 8;
303 vst1_u8(output1, vout1);
304 output1 += 8;
305 vst1_u8(output2, vout2);
306 output2 += 8;
307 }
308 if (c != 0) {
309 const size_t c_predecrement = 8 - c;
310 const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
311 i00 -= c_predecrement;
312 i10 -= c_predecrement;
313 i20 -= c_predecrement;
314 i01 -= c_predecrement;
315 i11 -= c_predecrement;
316 i21 -= c_predecrement;
317 i02 -= c_predecrement;
318 i12 -= c_predecrement;
319 i22 -= c_predecrement;
320 i03 -= c_predecrement;
321 i13 -= c_predecrement;
322 i23 -= c_predecrement;
323 i04 -= c_predecrement;
324 i14 -= c_predecrement;
325 i24 -= c_predecrement;
326
327 int32x4_t vacc0_lo = vld1q_s32(w);
328 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
329 int32x4_t vacc0_hi = vld1q_s32(w);
330 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
331 int32x4_t vacc1_lo = vacc0_lo;
332 int32x4_t vacc2_lo = vacc0_lo;
333 int32x4_t vacc1_hi = vacc0_hi;
334 int32x4_t vacc2_hi = vacc0_hi;
335
336 const uint8x8_t vk00 = vld1_u8(w);
337 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
338 const uint8x8_t vi00 = vreinterpret_u8_u64(
339 vshl_u64(vreinterpret_u64_u8(vld1_u8(i00)), vi_shift));
340 const uint8x8_t vi01 = vreinterpret_u8_u64(
341 vshl_u64(vreinterpret_u64_u8(vld1_u8(i01)), vi_shift));
342 const uint8x8_t vi02 = vreinterpret_u8_u64(
343 vshl_u64(vreinterpret_u64_u8(vld1_u8(i02)), vi_shift));
344 const int16x8_t vxk00 =
345 vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
346 const int16x8_t vxi00 =
347 vreinterpretq_s16_u16(sub_zero_point(vi00, va_zero_point));
348 const int16x8_t vxi01 =
349 vreinterpretq_s16_u16(sub_zero_point(vi01, va_zero_point));
350 const int16x8_t vxi02 =
351 vreinterpretq_s16_u16(sub_zero_point(vi02, va_zero_point));
352 vacc0_lo =
353 vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
354 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
355 vacc1_lo =
356 vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
357 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
358 vacc2_lo =
359 vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
360 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
361
362 const uint8x8_t vk10 = vld1_u8(w);
363 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
364 const uint8x8_t vi10 = vreinterpret_u8_u64(
365 vshl_u64(vreinterpret_u64_u8(vld1_u8(i10)), vi_shift));
366 const uint8x8_t vi11 = vreinterpret_u8_u64(
367 vshl_u64(vreinterpret_u64_u8(vld1_u8(i11)), vi_shift));
368 const uint8x8_t vi12 = vreinterpret_u8_u64(
369 vshl_u64(vreinterpret_u64_u8(vld1_u8(i12)), vi_shift));
370 const int16x8_t vxk10 =
371 vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
372 const int16x8_t vxi10 =
373 vreinterpretq_s16_u16(sub_zero_point(vi10, va_zero_point));
374 const int16x8_t vxi11 =
375 vreinterpretq_s16_u16(sub_zero_point(vi11, va_zero_point));
376 const int16x8_t vxi12 =
377 vreinterpretq_s16_u16(sub_zero_point(vi12, va_zero_point));
378 vacc0_lo =
379 vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
380 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
381 vacc1_lo =
382 vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
383 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
384 vacc2_lo =
385 vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
386 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
387
388 const uint8x8_t vk20 = vld1_u8(w);
389 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
390 const uint8x8_t vi20 = vreinterpret_u8_u64(
391 vshl_u64(vreinterpret_u64_u8(vld1_u8(i20)), vi_shift));
392 const uint8x8_t vi21 = vreinterpret_u8_u64(
393 vshl_u64(vreinterpret_u64_u8(vld1_u8(i21)), vi_shift));
394 const uint8x8_t vi22 = vreinterpret_u8_u64(
395 vshl_u64(vreinterpret_u64_u8(vld1_u8(i22)), vi_shift));
396 const int16x8_t vxk20 =
397 vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
398 const int16x8_t vxi20 =
399 vreinterpretq_s16_u16(sub_zero_point(vi20, va_zero_point));
400 const int16x8_t vxi21 =
401 vreinterpretq_s16_u16(sub_zero_point(vi21, va_zero_point));
402 const int16x8_t vxi22 =
403 vreinterpretq_s16_u16(sub_zero_point(vi22, va_zero_point));
404 vacc0_lo =
405 vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
406 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
407 vacc1_lo =
408 vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
409 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
410 vacc2_lo =
411 vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
412 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
413
414 const uint8x8_t vk01 = vld1_u8(w);
415 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
416 const uint8x8_t vi03 = vreinterpret_u8_u64(
417 vshl_u64(vreinterpret_u64_u8(vld1_u8(i03)), vi_shift));
418 const int16x8_t vxk01 =
419 vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
420 const int16x8_t vxi03 =
421 vreinterpretq_s16_u16(sub_zero_point(vi03, va_zero_point));
422 vacc0_lo =
423 vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
424 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
425 vacc1_lo =
426 vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
427 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
428 vacc2_lo =
429 vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
430 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
431
432 const uint8x8_t vk11 = vld1_u8(w);
433 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
434 const uint8x8_t vi13 = vreinterpret_u8_u64(
435 vshl_u64(vreinterpret_u64_u8(vld1_u8(i13)), vi_shift));
436 const int16x8_t vxk11 =
437 vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
438 const int16x8_t vxi13 =
439 vreinterpretq_s16_u16(sub_zero_point(vi13, va_zero_point));
440 vacc0_lo =
441 vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
442 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
443 vacc1_lo =
444 vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
445 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
446 vacc2_lo =
447 vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
448 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
449
450 const uint8x8_t vk21 = vld1_u8(w);
451 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
452 const uint8x8_t vi23 = vreinterpret_u8_u64(
453 vshl_u64(vreinterpret_u64_u8(vld1_u8(i23)), vi_shift));
454 const int16x8_t vxk21 =
455 vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
456 const int16x8_t vxi23 =
457 vreinterpretq_s16_u16(sub_zero_point(vi23, va_zero_point));
458 vacc0_lo =
459 vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
460 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
461 vacc1_lo =
462 vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
463 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
464 vacc2_lo =
465 vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
466 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
467
468 const uint8x8_t vk02 = vld1_u8(w);
469 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
470 const uint8x8_t vi04 = vreinterpret_u8_u64(
471 vshl_u64(vreinterpret_u64_u8(vld1_u8(i04)), vi_shift));
472 const int16x8_t vxk02 =
473 vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
474 const int16x8_t vxi04 =
475 vreinterpretq_s16_u16(sub_zero_point(vi04, va_zero_point));
476 vacc0_lo =
477 vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
478 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
479 vacc1_lo =
480 vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
481 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
482 vacc2_lo =
483 vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
484 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
485
486 const uint8x8_t vk12 = vld1_u8(w);
487 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
488 const uint8x8_t vi14 = vreinterpret_u8_u64(
489 vshl_u64(vreinterpret_u64_u8(vld1_u8(i14)), vi_shift));
490 const int16x8_t vxk12 =
491 vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
492 const int16x8_t vxi14 =
493 vreinterpretq_s16_u16(sub_zero_point(vi14, va_zero_point));
494 vacc0_lo =
495 vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
496 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
497 vacc1_lo =
498 vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
499 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
500 vacc2_lo =
501 vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
502 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
503
504 const uint8x8_t vk22 = vld1_u8(w);
505 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
506 const uint8x8_t vi24 = vreinterpret_u8_u64(
507 vshl_u64(vreinterpret_u64_u8(vld1_u8(i24)), vi_shift));
508 const int16x8_t vxk22 =
509 vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
510 const int16x8_t vxi24 =
511 vreinterpretq_s16_u16(sub_zero_point(vi24, va_zero_point));
512 vacc0_lo =
513 vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
514 vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
515 vacc1_lo =
516 vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
517 vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
518 vacc2_lo =
519 vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
520 vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
521
522 vacc0_lo = vcvtnq_s32_f32(
523 vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v));
524 vacc0_hi = vcvtnq_s32_f32(
525 vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v));
526 vacc1_lo = vcvtnq_s32_f32(
527 vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v));
528 vacc1_hi = vcvtnq_s32_f32(
529 vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v));
530 vacc2_lo = vcvtnq_s32_f32(
531 vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v));
532 vacc2_hi = vcvtnq_s32_f32(
533 vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v));
534
535 const int16x8_t vacc0 = vqaddq_s16(
536 vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi),
537 voutput_zero_point);
538 const int16x8_t vacc1 = vqaddq_s16(
539 vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi),
540 voutput_zero_point);
541 const int16x8_t vacc2 = vqaddq_s16(
542 vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi),
543 voutput_zero_point);
544 uint8x8_t vout0 = vqmovun_s16(vacc0);
545 uint8x8_t vout1 = vqmovun_s16(vacc1);
546 uint8x8_t vout2 = vqmovun_s16(vacc2);
547 vout0 = vmax_u8(vout0, voutput_min);
548 vout1 = vmax_u8(vout1, voutput_min);
549 vout2 = vmax_u8(vout2, voutput_min);
550 vout0 = vmin_u8(vout0, voutput_max);
551 vout1 = vmin_u8(vout1, voutput_max);
552 vout2 = vmin_u8(vout2, voutput_max);
553
554 if (c & 4) {
555 vst1_lane_u32(
556 __builtin_assume_aligned(output0, 1),
557 vreinterpret_u32_u8(vout0),
558 0);
559 output0 += 4;
560 vst1_lane_u32(
561 __builtin_assume_aligned(output1, 1),
562 vreinterpret_u32_u8(vout1),
563 0);
564 output1 += 4;
565 vst1_lane_u32(
566 __builtin_assume_aligned(output2, 1),
567 vreinterpret_u32_u8(vout2),
568 0);
569 output2 += 4;
570 vout0 = vext_u8(vout0, vout0, 4);
571 vout1 = vext_u8(vout1, vout1, 4);
572 vout2 = vext_u8(vout2, vout2, 4);
573 }
574 if (c & 2) {
575 vst1_lane_u16(
576 __builtin_assume_aligned(output0, 1),
577 vreinterpret_u16_u8(vout0),
578 0);
579 output0 += 2;
580 vst1_lane_u16(
581 __builtin_assume_aligned(output1, 1),
582 vreinterpret_u16_u8(vout1),
583 0);
584 output1 += 2;
585 vst1_lane_u16(
586 __builtin_assume_aligned(output2, 1),
587 vreinterpret_u16_u8(vout2),
588 0);
589 output2 += 2;
590 vout0 = vext_u8(vout0, vout0, 2);
591 vout1 = vext_u8(vout1, vout1, 2);
592 vout2 = vext_u8(vout2, vout2, 2);
593 }
594 if (c & 1) {
595 vst1_lane_u8(__builtin_assume_aligned(output0, 1), vout0, 0);
596 output0++;
597 vst1_lane_u8(__builtin_assume_aligned(output1, 1), vout1, 0);
598 output1++;
599 vst1_lane_u8(__builtin_assume_aligned(output2, 1), vout2, 0);
600 output2++;
601 }
602 }
603
604 output = (uint8_t*)((uintptr_t)output2 + output_increment);
605 }
606 if (output_width == 0) {
607 return;
608 }
609 }
610 #endif
611
612 do {
613 const uint8_t* i0 = input[0];
614 const uint8_t* i1 = input[1];
615 const uint8_t* i2 = input[2];
616 const uint8_t* i3 = input[3];
617 const uint8_t* i4 = input[4];
618 const uint8_t* i5 = input[5];
619 const uint8_t* i6 = input[6];
620 const uint8_t* i7 = input[7];
621 const uint8_t* i8 = input[8];
622
623 input = (const uint8_t**)((uintptr_t)input + input_stride);
624
625 size_t c = channels;
626 const void* w = weights;
627 for (; c >= 8; c -= 8) {
628 int32x4_t vaccX1_lo = vld1q_s32(w);
629 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
630 int32x4_t vaccX1_hi = vld1q_s32(w);
631 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
632
633 const uint8x8_t vk0 = vld1_u8(w);
634 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
635 const uint8x8_t vi0 = vld1_u8(i0);
636 i0 += 8;
637 const int16x8_t vxk0 =
638 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
639 const int16x8_t vxi0 =
640 vreinterpretq_s16_u16(sub_zero_point(vi0, va_zero_point));
641 int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
642 int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
643
644 const uint8x8_t vk1 = vld1_u8(w);
645 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
646 const uint8x8_t vi1 = vld1_u8(i1);
647 i1 += 8;
648 const int16x8_t vxk1 =
649 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
650 const int16x8_t vxi1 =
651 vreinterpretq_s16_u16(sub_zero_point(vi1, va_zero_point));
652 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
653 vaccX1_hi =
654 vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
655
656 const uint8x8_t vk2 = vld1_u8(w);
657 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
658 const uint8x8_t vi2 = vld1_u8(i2);
659 i2 += 8;
660 const int16x8_t vxk2 =
661 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
662 const int16x8_t vxi2 =
663 vreinterpretq_s16_u16(sub_zero_point(vi2, va_zero_point));
664 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
665 vaccX0_hi =
666 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
667
668 const uint8x8_t vk3 = vld1_u8(w);
669 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
670 const uint8x8_t vi3 = vld1_u8(i3);
671 i3 += 8;
672 const int16x8_t vxk3 =
673 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
674 const int16x8_t vxi3 =
675 vreinterpretq_s16_u16(sub_zero_point(vi3, va_zero_point));
676 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
677 vaccX1_hi =
678 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
679
680 const uint8x8_t vk4 = vld1_u8(w);
681 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
682 const uint8x8_t vi4 = vld1_u8(i4);
683 i4 += 8;
684 const int16x8_t vxk4 =
685 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
686 const int16x8_t vxi4 =
687 vreinterpretq_s16_u16(sub_zero_point(vi4, va_zero_point));
688 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
689 vaccX0_hi =
690 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
691
692 const uint8x8_t vk5 = vld1_u8(w);
693 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
694 const uint8x8_t vi5 = vld1_u8(i5);
695 i5 += 8;
696 const int16x8_t vxk5 =
697 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
698 const int16x8_t vxi5 =
699 vreinterpretq_s16_u16(sub_zero_point(vi5, va_zero_point));
700 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
701 vaccX1_hi =
702 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
703
704 const uint8x8_t vk6 = vld1_u8(w);
705 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
706 const uint8x8_t vi6 = vld1_u8(i6);
707 i6 += 8;
708 const int16x8_t vxk6 =
709 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
710 const int16x8_t vxi6 =
711 vreinterpretq_s16_u16(sub_zero_point(vi6, va_zero_point));
712 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
713 vaccX0_hi =
714 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
715
716 const uint8x8_t vk7 = vld1_u8(w);
717 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
718 const uint8x8_t vi7 = vld1_u8(i7);
719 i7 += 8;
720 const int16x8_t vxk7 =
721 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
722 const int16x8_t vxi7 =
723 vreinterpretq_s16_u16(sub_zero_point(vi7, va_zero_point));
724 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
725 vaccX1_hi =
726 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
727
728 const uint8x8_t vk8 = vld1_u8(w);
729 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
730 const uint8x8_t vi8 = vld1_u8(i8);
731 i8 += 8;
732 const int16x8_t vxk8 =
733 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
734 const int16x8_t vxi8 =
735 vreinterpretq_s16_u16(sub_zero_point(vi8, va_zero_point));
736 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
737 vaccX0_hi =
738 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
739
740 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
741 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
742
743 const float32x4_t vacc_lo_f =
744 vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
745 const float32x4_t vacc_hi_f =
746 vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
747
748 #ifdef __aarch64__
749 vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
750 vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
751
752 const int16x8_t vacc = vqaddq_s16(
753 vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
754
755 uint8x8_t vout = vqmovun_s16(vacc);
756 vout = vmax_u8(vout, voutput_min);
757 vout = vmin_u8(vout, voutput_max);
758 #else
759 const float32x4_t vacc_lo_f_clamped =
760 vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
761 const float32x4_t vacc_hi_f_clamped =
762 vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
763 vacc_lo = vsubq_s32(
764 vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)), vimagic);
765 vacc_hi = vsubq_s32(
766 vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)), vimagic);
767 const int16x8_t vacc =
768 vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
769
770 uint8x8_t vout = vqmovun_s16(vacc);
771 #endif
772
773 vst1_u8(output, vout);
774 output += 8;
775 }
776 if (c != 0) {
777 const size_t c_predecrement = 8 - c;
778 const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
779 i0 -= c_predecrement;
780 i1 -= c_predecrement;
781 i2 -= c_predecrement;
782 i3 -= c_predecrement;
783 i4 -= c_predecrement;
784 i5 -= c_predecrement;
785 i6 -= c_predecrement;
786 i7 -= c_predecrement;
787 i8 -= c_predecrement;
788
789 int32x4_t vaccX1_lo = vld1q_s32(w);
790 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
791 int32x4_t vaccX1_hi = vld1q_s32(w);
792 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
793
794 const uint8x8_t vk0 = vld1_u8(w);
795 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
796 const uint8x8_t vi0 = vreinterpret_u8_u64(
797 vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
798 const int16x8_t vxk0 =
799 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
800 const int16x8_t vxi0 =
801 vreinterpretq_s16_u16(sub_zero_point(vi0, va_zero_point));
802 int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
803 int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
804
805 const uint8x8_t vk1 = vld1_u8(w);
806 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
807 const uint8x8_t vi1 = vreinterpret_u8_u64(
808 vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
809 const int16x8_t vxk1 =
810 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
811 const int16x8_t vxi1 =
812 vreinterpretq_s16_u16(sub_zero_point(vi1, va_zero_point));
813 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
814 vaccX1_hi =
815 vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
816
817 const uint8x8_t vk2 = vld1_u8(w);
818 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
819 const uint8x8_t vi2 = vreinterpret_u8_u64(
820 vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
821 const int16x8_t vxk2 =
822 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
823 const int16x8_t vxi2 =
824 vreinterpretq_s16_u16(sub_zero_point(vi2, va_zero_point));
825 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
826 vaccX0_hi =
827 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
828
829 const uint8x8_t vk3 = vld1_u8(w);
830 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
831 const uint8x8_t vi3 = vreinterpret_u8_u64(
832 vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
833 const int16x8_t vxk3 =
834 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
835 const int16x8_t vxi3 =
836 vreinterpretq_s16_u16(sub_zero_point(vi3, va_zero_point));
837 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
838 vaccX1_hi =
839 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
840
841 const uint8x8_t vk4 = vld1_u8(w);
842 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
843 const uint8x8_t vi4 = vreinterpret_u8_u64(
844 vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
845 const int16x8_t vxk4 =
846 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
847 const int16x8_t vxi4 =
848 vreinterpretq_s16_u16(sub_zero_point(vi4, va_zero_point));
849 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
850 vaccX0_hi =
851 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
852
853 const uint8x8_t vk5 = vld1_u8(w);
854 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
855 const uint8x8_t vi5 = vreinterpret_u8_u64(
856 vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
857 const int16x8_t vxk5 =
858 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
859 const int16x8_t vxi5 =
860 vreinterpretq_s16_u16(sub_zero_point(vi5, va_zero_point));
861 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
862 vaccX1_hi =
863 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
864
865 const uint8x8_t vk6 = vld1_u8(w);
866 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
867 const uint8x8_t vi6 = vreinterpret_u8_u64(
868 vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
869 const int16x8_t vxk6 =
870 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
871 const int16x8_t vxi6 =
872 vreinterpretq_s16_u16(sub_zero_point(vi6, va_zero_point));
873 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
874 vaccX0_hi =
875 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
876
877 const uint8x8_t vk7 = vld1_u8(w);
878 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
879 const uint8x8_t vi7 = vreinterpret_u8_u64(
880 vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
881 const int16x8_t vxk7 =
882 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
883 const int16x8_t vxi7 =
884 vreinterpretq_s16_u16(sub_zero_point(vi7, va_zero_point));
885 vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
886 vaccX1_hi =
887 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
888
889 const uint8x8_t vk8 = vld1_u8(w);
890 const uint8x8_t vi8 = vreinterpret_u8_u64(
891 vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
892 const int16x8_t vxk8 =
893 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
894 const int16x8_t vxi8 =
895 vreinterpretq_s16_u16(sub_zero_point(vi8, va_zero_point));
896 vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
897 vaccX0_hi =
898 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
899
900 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
901 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
902
903 const float32x4_t vacc_lo_f =
904 vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
905 const float32x4_t vacc_hi_f =
906 vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
907
908 #ifdef __aarch64__
909 vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
910 vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
911
912 const int16x8_t vacc = vqaddq_s16(
913 vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
914
915 uint8x8_t vout = vqmovun_s16(vacc);
916 vout = vmax_u8(vout, voutput_min);
917 vout = vmin_u8(vout, voutput_max);
918 #else
919 const float32x4_t vacc_lo_f_clamped =
920 vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
921 const float32x4_t vacc_hi_f_clamped =
922 vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
923 vacc_lo = vsubq_s32(
924 vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)), vimagic);
925 vacc_hi = vsubq_s32(
926 vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)), vimagic);
927 const int16x8_t vacc =
928 vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
929
930 uint8x8_t vout = vqmovun_s16(vacc);
931 #endif
932
933 if (c & 4) {
934 vst1_lane_u32(
935 __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
936 output += 4;
937 vout = vext_u8(vout, vout, 4);
938 }
939 if (c & 2) {
940 vst1_lane_u16(
941 __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
942 output += 2;
943 vout = vext_u8(vout, vout, 2);
944 }
945 if (c & 1) {
946 vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0);
947 output++;
948 }
949 }
950
951 output = (uint8_t*)((uintptr_t)output + output_increment);
952 } while (--output_width != 0);
953 }
954