xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <arm_neon.h>
10 
11 #include <qnnpack/q8dwconv.h>
12 #include <requantization/runtime-neon.h>
13 
pytorch_q8dwconv_ukernel_up8x9__neon(size_t channels,size_t output_width,const uint8_t ** input,const void * weights,uint8_t * output,size_t input_stride,size_t output_increment,const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static1])14 void pytorch_q8dwconv_ukernel_up8x9__neon(
15     size_t channels,
16     size_t output_width,
17     const uint8_t** input,
18     const void* weights,
19     uint8_t* output,
20     size_t input_stride,
21     size_t output_increment,
22     const union pytorch_qnnp_conv_quantization_params
23         quantization_params[restrict static 1]) {
24   const uint8x8_t va_zero_point =
25       vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
26   const uint8x8_t vkernel_zero_point =
27       vdup_n_u8(quantization_params->neon.kernel_zero_points[0]);
28   const float32x4_t requantization_scale_v =
29       vdupq_n_f32(quantization_params->neon.requantization_scales[0]);
30 #ifdef __aarch64__
31   const int16x8_t voutput_zero_point =
32       vld1q_dup_s16(&quantization_params->neon.output_zero_point);
33   const uint8x8_t voutput_min =
34       vld1_dup_u8(&quantization_params->neon.output_min);
35   const uint8x8_t voutput_max =
36       vld1_dup_u8(&quantization_params->neon.output_max);
37 #else
38   const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
39   const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
40   const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
41   const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
42 #endif
43 
44 #ifdef __aarch64__
45   /* Larger number of registers on AArch64 make it possible to process few
46    * pixels at a time */
47   if (input_stride == 3 * sizeof(void*)) {
48     for (; output_width >= 3; output_width -= 3) {
49       const uint8_t* i00 = input[0];
50       const uint8_t* i10 = input[1];
51       const uint8_t* i20 = input[2];
52       const uint8_t* i01 = input[3];
53       const uint8_t* i11 = input[4];
54       const uint8_t* i21 = input[5];
55       const uint8_t* i02 = input[6];
56       const uint8_t* i12 = input[7];
57       const uint8_t* i22 = input[8];
58       const uint8_t* i03 = input[9];
59       const uint8_t* i13 = input[10];
60       const uint8_t* i23 = input[11];
61       const uint8_t* i04 = input[12];
62       const uint8_t* i14 = input[13];
63       const uint8_t* i24 = input[14];
64 
65       uint8_t* output0 = output;
66       uint8_t* output1 = output0 + channels + output_increment;
67       uint8_t* output2 = output1 + channels + output_increment;
68 
69       input += 9;
70 
71       size_t c = channels;
72       const void* w = weights;
73       for (; c >= 8; c -= 8) {
74         int32x4_t vacc0_lo = vld1q_s32(w);
75         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
76         int32x4_t vacc0_hi = vld1q_s32(w);
77         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
78         int32x4_t vacc1_lo = vacc0_lo;
79         int32x4_t vacc2_lo = vacc0_lo;
80         int32x4_t vacc1_hi = vacc0_hi;
81         int32x4_t vacc2_hi = vacc0_hi;
82 
83         const uint8x8_t vk00 = vld1_u8(w);
84         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
85         const uint8x8_t vi00 = vld1_u8(i00);
86         i00 += 8;
87         const uint8x8_t vi01 = vld1_u8(i01);
88         i01 += 8;
89         const uint8x8_t vi02 = vld1_u8(i02);
90         i02 += 8;
91         const int16x8_t vxk00 =
92             vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
93         const int16x8_t vxi00 =
94             vreinterpretq_s16_u16(sub_zero_point(vi00, va_zero_point));
95         const int16x8_t vxi01 =
96             vreinterpretq_s16_u16(sub_zero_point(vi01, va_zero_point));
97         const int16x8_t vxi02 =
98             vreinterpretq_s16_u16(sub_zero_point(vi02, va_zero_point));
99         vacc0_lo =
100             vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
101         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
102         vacc1_lo =
103             vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
104         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
105         vacc2_lo =
106             vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
107         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
108 
109         const uint8x8_t vk10 = vld1_u8(w);
110         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
111         const uint8x8_t vi10 = vld1_u8(i10);
112         i10 += 8;
113         const uint8x8_t vi11 = vld1_u8(i11);
114         i11 += 8;
115         const uint8x8_t vi12 = vld1_u8(i12);
116         i12 += 8;
117         const int16x8_t vxk10 =
118             vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
119         const int16x8_t vxi10 =
120             vreinterpretq_s16_u16(sub_zero_point(vi10, va_zero_point));
121         const int16x8_t vxi11 =
122             vreinterpretq_s16_u16(sub_zero_point(vi11, va_zero_point));
123         const int16x8_t vxi12 =
124             vreinterpretq_s16_u16(sub_zero_point(vi12, va_zero_point));
125         vacc0_lo =
126             vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
127         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
128         vacc1_lo =
129             vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
130         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
131         vacc2_lo =
132             vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
133         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
134 
135         const uint8x8_t vk20 = vld1_u8(w);
136         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
137         const uint8x8_t vi20 = vld1_u8(i20);
138         i20 += 8;
139         const uint8x8_t vi21 = vld1_u8(i21);
140         i21 += 8;
141         const uint8x8_t vi22 = vld1_u8(i22);
142         i22 += 8;
143         const int16x8_t vxk20 =
144             vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
145         const int16x8_t vxi20 =
146             vreinterpretq_s16_u16(sub_zero_point(vi20, va_zero_point));
147         const int16x8_t vxi21 =
148             vreinterpretq_s16_u16(sub_zero_point(vi21, va_zero_point));
149         const int16x8_t vxi22 =
150             vreinterpretq_s16_u16(sub_zero_point(vi22, va_zero_point));
151         vacc0_lo =
152             vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
153         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
154         vacc1_lo =
155             vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
156         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
157         vacc2_lo =
158             vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
159         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
160 
161         const uint8x8_t vk01 = vld1_u8(w);
162         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
163         const uint8x8_t vi03 = vld1_u8(i03);
164         i03 += 8;
165         const int16x8_t vxk01 =
166             vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
167         const int16x8_t vxi03 =
168             vreinterpretq_s16_u16(sub_zero_point(vi03, va_zero_point));
169         vacc0_lo =
170             vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
171         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
172         vacc1_lo =
173             vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
174         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
175         vacc2_lo =
176             vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
177         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
178 
179         const uint8x8_t vk11 = vld1_u8(w);
180         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
181         const uint8x8_t vi13 = vld1_u8(i13);
182         i13 += 8;
183         const int16x8_t vxk11 =
184             vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
185         const int16x8_t vxi13 =
186             vreinterpretq_s16_u16(sub_zero_point(vi13, va_zero_point));
187         vacc0_lo =
188             vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
189         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
190         vacc1_lo =
191             vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
192         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
193         vacc2_lo =
194             vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
195         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
196 
197         const uint8x8_t vk21 = vld1_u8(w);
198         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
199         const uint8x8_t vi23 = vld1_u8(i23);
200         i23 += 8;
201         const int16x8_t vxk21 =
202             vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
203         const int16x8_t vxi23 =
204             vreinterpretq_s16_u16(sub_zero_point(vi23, va_zero_point));
205         vacc0_lo =
206             vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
207         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
208         vacc1_lo =
209             vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
210         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
211         vacc2_lo =
212             vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
213         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
214 
215         const uint8x8_t vk02 = vld1_u8(w);
216         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
217         const uint8x8_t vi04 = vld1_u8(i04);
218         i04 += 8;
219         const int16x8_t vxk02 =
220             vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
221         const int16x8_t vxi04 =
222             vreinterpretq_s16_u16(sub_zero_point(vi04, va_zero_point));
223         vacc0_lo =
224             vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
225         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
226         vacc1_lo =
227             vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
228         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
229         vacc2_lo =
230             vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
231         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
232 
233         const uint8x8_t vk12 = vld1_u8(w);
234         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
235         const uint8x8_t vi14 = vld1_u8(i14);
236         i14 += 8;
237         const int16x8_t vxk12 =
238             vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
239         const int16x8_t vxi14 =
240             vreinterpretq_s16_u16(sub_zero_point(vi14, va_zero_point));
241         vacc0_lo =
242             vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
243         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
244         vacc1_lo =
245             vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
246         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
247         vacc2_lo =
248             vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
249         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
250 
251         const uint8x8_t vk22 = vld1_u8(w);
252         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
253         const uint8x8_t vi24 = vld1_u8(i24);
254         i24 += 8;
255         const int16x8_t vxk22 =
256             vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
257         const int16x8_t vxi24 =
258             vreinterpretq_s16_u16(sub_zero_point(vi24, va_zero_point));
259         vacc0_lo =
260             vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
261         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
262         vacc1_lo =
263             vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
264         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
265         vacc2_lo =
266             vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
267         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
268 
269         vacc0_lo = vcvtnq_s32_f32(
270             vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v));
271         vacc0_hi = vcvtnq_s32_f32(
272             vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v));
273         vacc1_lo = vcvtnq_s32_f32(
274             vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v));
275         vacc1_hi = vcvtnq_s32_f32(
276             vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v));
277         vacc2_lo = vcvtnq_s32_f32(
278             vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v));
279         vacc2_hi = vcvtnq_s32_f32(
280             vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v));
281 
282         const int16x8_t vacc0 = vqaddq_s16(
283             vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi),
284             voutput_zero_point);
285         const int16x8_t vacc1 = vqaddq_s16(
286             vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi),
287             voutput_zero_point);
288         const int16x8_t vacc2 = vqaddq_s16(
289             vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi),
290             voutput_zero_point);
291         uint8x8_t vout0 = vqmovun_s16(vacc0);
292         uint8x8_t vout1 = vqmovun_s16(vacc1);
293         uint8x8_t vout2 = vqmovun_s16(vacc2);
294         vout0 = vmax_u8(vout0, voutput_min);
295         vout1 = vmax_u8(vout1, voutput_min);
296         vout2 = vmax_u8(vout2, voutput_min);
297         vout0 = vmin_u8(vout0, voutput_max);
298         vout1 = vmin_u8(vout1, voutput_max);
299         vout2 = vmin_u8(vout2, voutput_max);
300 
301         vst1_u8(output0, vout0);
302         output0 += 8;
303         vst1_u8(output1, vout1);
304         output1 += 8;
305         vst1_u8(output2, vout2);
306         output2 += 8;
307       }
308       if (c != 0) {
309         const size_t c_predecrement = 8 - c;
310         const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
311         i00 -= c_predecrement;
312         i10 -= c_predecrement;
313         i20 -= c_predecrement;
314         i01 -= c_predecrement;
315         i11 -= c_predecrement;
316         i21 -= c_predecrement;
317         i02 -= c_predecrement;
318         i12 -= c_predecrement;
319         i22 -= c_predecrement;
320         i03 -= c_predecrement;
321         i13 -= c_predecrement;
322         i23 -= c_predecrement;
323         i04 -= c_predecrement;
324         i14 -= c_predecrement;
325         i24 -= c_predecrement;
326 
327         int32x4_t vacc0_lo = vld1q_s32(w);
328         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
329         int32x4_t vacc0_hi = vld1q_s32(w);
330         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
331         int32x4_t vacc1_lo = vacc0_lo;
332         int32x4_t vacc2_lo = vacc0_lo;
333         int32x4_t vacc1_hi = vacc0_hi;
334         int32x4_t vacc2_hi = vacc0_hi;
335 
336         const uint8x8_t vk00 = vld1_u8(w);
337         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
338         const uint8x8_t vi00 = vreinterpret_u8_u64(
339             vshl_u64(vreinterpret_u64_u8(vld1_u8(i00)), vi_shift));
340         const uint8x8_t vi01 = vreinterpret_u8_u64(
341             vshl_u64(vreinterpret_u64_u8(vld1_u8(i01)), vi_shift));
342         const uint8x8_t vi02 = vreinterpret_u8_u64(
343             vshl_u64(vreinterpret_u64_u8(vld1_u8(i02)), vi_shift));
344         const int16x8_t vxk00 =
345             vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point));
346         const int16x8_t vxi00 =
347             vreinterpretq_s16_u16(sub_zero_point(vi00, va_zero_point));
348         const int16x8_t vxi01 =
349             vreinterpretq_s16_u16(sub_zero_point(vi01, va_zero_point));
350         const int16x8_t vxi02 =
351             vreinterpretq_s16_u16(sub_zero_point(vi02, va_zero_point));
352         vacc0_lo =
353             vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00));
354         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00);
355         vacc1_lo =
356             vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01));
357         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01);
358         vacc2_lo =
359             vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02));
360         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02);
361 
362         const uint8x8_t vk10 = vld1_u8(w);
363         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
364         const uint8x8_t vi10 = vreinterpret_u8_u64(
365             vshl_u64(vreinterpret_u64_u8(vld1_u8(i10)), vi_shift));
366         const uint8x8_t vi11 = vreinterpret_u8_u64(
367             vshl_u64(vreinterpret_u64_u8(vld1_u8(i11)), vi_shift));
368         const uint8x8_t vi12 = vreinterpret_u8_u64(
369             vshl_u64(vreinterpret_u64_u8(vld1_u8(i12)), vi_shift));
370         const int16x8_t vxk10 =
371             vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point));
372         const int16x8_t vxi10 =
373             vreinterpretq_s16_u16(sub_zero_point(vi10, va_zero_point));
374         const int16x8_t vxi11 =
375             vreinterpretq_s16_u16(sub_zero_point(vi11, va_zero_point));
376         const int16x8_t vxi12 =
377             vreinterpretq_s16_u16(sub_zero_point(vi12, va_zero_point));
378         vacc0_lo =
379             vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10));
380         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10);
381         vacc1_lo =
382             vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11));
383         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11);
384         vacc2_lo =
385             vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12));
386         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12);
387 
388         const uint8x8_t vk20 = vld1_u8(w);
389         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
390         const uint8x8_t vi20 = vreinterpret_u8_u64(
391             vshl_u64(vreinterpret_u64_u8(vld1_u8(i20)), vi_shift));
392         const uint8x8_t vi21 = vreinterpret_u8_u64(
393             vshl_u64(vreinterpret_u64_u8(vld1_u8(i21)), vi_shift));
394         const uint8x8_t vi22 = vreinterpret_u8_u64(
395             vshl_u64(vreinterpret_u64_u8(vld1_u8(i22)), vi_shift));
396         const int16x8_t vxk20 =
397             vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point));
398         const int16x8_t vxi20 =
399             vreinterpretq_s16_u16(sub_zero_point(vi20, va_zero_point));
400         const int16x8_t vxi21 =
401             vreinterpretq_s16_u16(sub_zero_point(vi21, va_zero_point));
402         const int16x8_t vxi22 =
403             vreinterpretq_s16_u16(sub_zero_point(vi22, va_zero_point));
404         vacc0_lo =
405             vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20));
406         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20);
407         vacc1_lo =
408             vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21));
409         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21);
410         vacc2_lo =
411             vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22));
412         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22);
413 
414         const uint8x8_t vk01 = vld1_u8(w);
415         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
416         const uint8x8_t vi03 = vreinterpret_u8_u64(
417             vshl_u64(vreinterpret_u64_u8(vld1_u8(i03)), vi_shift));
418         const int16x8_t vxk01 =
419             vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point));
420         const int16x8_t vxi03 =
421             vreinterpretq_s16_u16(sub_zero_point(vi03, va_zero_point));
422         vacc0_lo =
423             vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01));
424         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01);
425         vacc1_lo =
426             vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02));
427         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02);
428         vacc2_lo =
429             vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03));
430         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03);
431 
432         const uint8x8_t vk11 = vld1_u8(w);
433         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
434         const uint8x8_t vi13 = vreinterpret_u8_u64(
435             vshl_u64(vreinterpret_u64_u8(vld1_u8(i13)), vi_shift));
436         const int16x8_t vxk11 =
437             vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point));
438         const int16x8_t vxi13 =
439             vreinterpretq_s16_u16(sub_zero_point(vi13, va_zero_point));
440         vacc0_lo =
441             vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11));
442         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11);
443         vacc1_lo =
444             vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12));
445         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12);
446         vacc2_lo =
447             vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13));
448         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13);
449 
450         const uint8x8_t vk21 = vld1_u8(w);
451         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
452         const uint8x8_t vi23 = vreinterpret_u8_u64(
453             vshl_u64(vreinterpret_u64_u8(vld1_u8(i23)), vi_shift));
454         const int16x8_t vxk21 =
455             vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point));
456         const int16x8_t vxi23 =
457             vreinterpretq_s16_u16(sub_zero_point(vi23, va_zero_point));
458         vacc0_lo =
459             vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21));
460         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21);
461         vacc1_lo =
462             vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22));
463         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22);
464         vacc2_lo =
465             vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23));
466         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23);
467 
468         const uint8x8_t vk02 = vld1_u8(w);
469         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
470         const uint8x8_t vi04 = vreinterpret_u8_u64(
471             vshl_u64(vreinterpret_u64_u8(vld1_u8(i04)), vi_shift));
472         const int16x8_t vxk02 =
473             vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point));
474         const int16x8_t vxi04 =
475             vreinterpretq_s16_u16(sub_zero_point(vi04, va_zero_point));
476         vacc0_lo =
477             vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02));
478         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02);
479         vacc1_lo =
480             vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03));
481         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03);
482         vacc2_lo =
483             vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04));
484         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04);
485 
486         const uint8x8_t vk12 = vld1_u8(w);
487         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
488         const uint8x8_t vi14 = vreinterpret_u8_u64(
489             vshl_u64(vreinterpret_u64_u8(vld1_u8(i14)), vi_shift));
490         const int16x8_t vxk12 =
491             vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point));
492         const int16x8_t vxi14 =
493             vreinterpretq_s16_u16(sub_zero_point(vi14, va_zero_point));
494         vacc0_lo =
495             vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12));
496         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12);
497         vacc1_lo =
498             vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13));
499         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13);
500         vacc2_lo =
501             vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14));
502         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14);
503 
504         const uint8x8_t vk22 = vld1_u8(w);
505         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
506         const uint8x8_t vi24 = vreinterpret_u8_u64(
507             vshl_u64(vreinterpret_u64_u8(vld1_u8(i24)), vi_shift));
508         const int16x8_t vxk22 =
509             vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point));
510         const int16x8_t vxi24 =
511             vreinterpretq_s16_u16(sub_zero_point(vi24, va_zero_point));
512         vacc0_lo =
513             vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22));
514         vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22);
515         vacc1_lo =
516             vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23));
517         vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23);
518         vacc2_lo =
519             vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
520         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
521 
522         vacc0_lo = vcvtnq_s32_f32(
523             vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v));
524         vacc0_hi = vcvtnq_s32_f32(
525             vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v));
526         vacc1_lo = vcvtnq_s32_f32(
527             vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v));
528         vacc1_hi = vcvtnq_s32_f32(
529             vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v));
530         vacc2_lo = vcvtnq_s32_f32(
531             vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v));
532         vacc2_hi = vcvtnq_s32_f32(
533             vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v));
534 
535         const int16x8_t vacc0 = vqaddq_s16(
536             vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi),
537             voutput_zero_point);
538         const int16x8_t vacc1 = vqaddq_s16(
539             vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi),
540             voutput_zero_point);
541         const int16x8_t vacc2 = vqaddq_s16(
542             vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi),
543             voutput_zero_point);
544         uint8x8_t vout0 = vqmovun_s16(vacc0);
545         uint8x8_t vout1 = vqmovun_s16(vacc1);
546         uint8x8_t vout2 = vqmovun_s16(vacc2);
547         vout0 = vmax_u8(vout0, voutput_min);
548         vout1 = vmax_u8(vout1, voutput_min);
549         vout2 = vmax_u8(vout2, voutput_min);
550         vout0 = vmin_u8(vout0, voutput_max);
551         vout1 = vmin_u8(vout1, voutput_max);
552         vout2 = vmin_u8(vout2, voutput_max);
553 
554         if (c & 4) {
555           vst1_lane_u32(
556               __builtin_assume_aligned(output0, 1),
557               vreinterpret_u32_u8(vout0),
558               0);
559           output0 += 4;
560           vst1_lane_u32(
561               __builtin_assume_aligned(output1, 1),
562               vreinterpret_u32_u8(vout1),
563               0);
564           output1 += 4;
565           vst1_lane_u32(
566               __builtin_assume_aligned(output2, 1),
567               vreinterpret_u32_u8(vout2),
568               0);
569           output2 += 4;
570           vout0 = vext_u8(vout0, vout0, 4);
571           vout1 = vext_u8(vout1, vout1, 4);
572           vout2 = vext_u8(vout2, vout2, 4);
573         }
574         if (c & 2) {
575           vst1_lane_u16(
576               __builtin_assume_aligned(output0, 1),
577               vreinterpret_u16_u8(vout0),
578               0);
579           output0 += 2;
580           vst1_lane_u16(
581               __builtin_assume_aligned(output1, 1),
582               vreinterpret_u16_u8(vout1),
583               0);
584           output1 += 2;
585           vst1_lane_u16(
586               __builtin_assume_aligned(output2, 1),
587               vreinterpret_u16_u8(vout2),
588               0);
589           output2 += 2;
590           vout0 = vext_u8(vout0, vout0, 2);
591           vout1 = vext_u8(vout1, vout1, 2);
592           vout2 = vext_u8(vout2, vout2, 2);
593         }
594         if (c & 1) {
595           vst1_lane_u8(__builtin_assume_aligned(output0, 1), vout0, 0);
596           output0++;
597           vst1_lane_u8(__builtin_assume_aligned(output1, 1), vout1, 0);
598           output1++;
599           vst1_lane_u8(__builtin_assume_aligned(output2, 1), vout2, 0);
600           output2++;
601         }
602       }
603 
604       output = (uint8_t*)((uintptr_t)output2 + output_increment);
605     }
606     if (output_width == 0) {
607       return;
608     }
609   }
610 #endif
611 
612   do {
613     const uint8_t* i0 = input[0];
614     const uint8_t* i1 = input[1];
615     const uint8_t* i2 = input[2];
616     const uint8_t* i3 = input[3];
617     const uint8_t* i4 = input[4];
618     const uint8_t* i5 = input[5];
619     const uint8_t* i6 = input[6];
620     const uint8_t* i7 = input[7];
621     const uint8_t* i8 = input[8];
622 
623     input = (const uint8_t**)((uintptr_t)input + input_stride);
624 
625     size_t c = channels;
626     const void* w = weights;
627     for (; c >= 8; c -= 8) {
628       int32x4_t vaccX1_lo = vld1q_s32(w);
629       w = (void*)((uintptr_t)w + sizeof(int32x4_t));
630       int32x4_t vaccX1_hi = vld1q_s32(w);
631       w = (void*)((uintptr_t)w + sizeof(int32x4_t));
632 
633       const uint8x8_t vk0 = vld1_u8(w);
634       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
635       const uint8x8_t vi0 = vld1_u8(i0);
636       i0 += 8;
637       const int16x8_t vxk0 =
638           vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
639       const int16x8_t vxi0 =
640           vreinterpretq_s16_u16(sub_zero_point(vi0, va_zero_point));
641       int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
642       int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
643 
644       const uint8x8_t vk1 = vld1_u8(w);
645       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
646       const uint8x8_t vi1 = vld1_u8(i1);
647       i1 += 8;
648       const int16x8_t vxk1 =
649           vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
650       const int16x8_t vxi1 =
651           vreinterpretq_s16_u16(sub_zero_point(vi1, va_zero_point));
652       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
653       vaccX1_hi =
654           vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
655 
656       const uint8x8_t vk2 = vld1_u8(w);
657       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
658       const uint8x8_t vi2 = vld1_u8(i2);
659       i2 += 8;
660       const int16x8_t vxk2 =
661           vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
662       const int16x8_t vxi2 =
663           vreinterpretq_s16_u16(sub_zero_point(vi2, va_zero_point));
664       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
665       vaccX0_hi =
666           vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
667 
668       const uint8x8_t vk3 = vld1_u8(w);
669       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
670       const uint8x8_t vi3 = vld1_u8(i3);
671       i3 += 8;
672       const int16x8_t vxk3 =
673           vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
674       const int16x8_t vxi3 =
675           vreinterpretq_s16_u16(sub_zero_point(vi3, va_zero_point));
676       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
677       vaccX1_hi =
678           vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
679 
680       const uint8x8_t vk4 = vld1_u8(w);
681       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
682       const uint8x8_t vi4 = vld1_u8(i4);
683       i4 += 8;
684       const int16x8_t vxk4 =
685           vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
686       const int16x8_t vxi4 =
687           vreinterpretq_s16_u16(sub_zero_point(vi4, va_zero_point));
688       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
689       vaccX0_hi =
690           vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
691 
692       const uint8x8_t vk5 = vld1_u8(w);
693       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
694       const uint8x8_t vi5 = vld1_u8(i5);
695       i5 += 8;
696       const int16x8_t vxk5 =
697           vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
698       const int16x8_t vxi5 =
699           vreinterpretq_s16_u16(sub_zero_point(vi5, va_zero_point));
700       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
701       vaccX1_hi =
702           vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
703 
704       const uint8x8_t vk6 = vld1_u8(w);
705       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
706       const uint8x8_t vi6 = vld1_u8(i6);
707       i6 += 8;
708       const int16x8_t vxk6 =
709           vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
710       const int16x8_t vxi6 =
711           vreinterpretq_s16_u16(sub_zero_point(vi6, va_zero_point));
712       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
713       vaccX0_hi =
714           vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
715 
716       const uint8x8_t vk7 = vld1_u8(w);
717       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
718       const uint8x8_t vi7 = vld1_u8(i7);
719       i7 += 8;
720       const int16x8_t vxk7 =
721           vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
722       const int16x8_t vxi7 =
723           vreinterpretq_s16_u16(sub_zero_point(vi7, va_zero_point));
724       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
725       vaccX1_hi =
726           vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
727 
728       const uint8x8_t vk8 = vld1_u8(w);
729       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
730       const uint8x8_t vi8 = vld1_u8(i8);
731       i8 += 8;
732       const int16x8_t vxk8 =
733           vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
734       const int16x8_t vxi8 =
735           vreinterpretq_s16_u16(sub_zero_point(vi8, va_zero_point));
736       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
737       vaccX0_hi =
738           vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
739 
740       int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
741       int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
742 
743       const float32x4_t vacc_lo_f =
744         vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
745       const float32x4_t vacc_hi_f =
746         vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
747 
748 #ifdef __aarch64__
749       vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
750       vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
751 
752       const int16x8_t vacc = vqaddq_s16(
753           vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
754 
755       uint8x8_t vout = vqmovun_s16(vacc);
756       vout = vmax_u8(vout, voutput_min);
757       vout = vmin_u8(vout, voutput_max);
758 #else
759       const float32x4_t vacc_lo_f_clamped =
760           vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
761       const float32x4_t vacc_hi_f_clamped =
762           vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
763       vacc_lo = vsubq_s32(
764           vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)), vimagic);
765       vacc_hi = vsubq_s32(
766           vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)), vimagic);
767       const int16x8_t vacc =
768           vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
769 
770       uint8x8_t vout = vqmovun_s16(vacc);
771 #endif
772 
773       vst1_u8(output, vout);
774       output += 8;
775     }
776     if (c != 0) {
777       const size_t c_predecrement = 8 - c;
778       const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
779       i0 -= c_predecrement;
780       i1 -= c_predecrement;
781       i2 -= c_predecrement;
782       i3 -= c_predecrement;
783       i4 -= c_predecrement;
784       i5 -= c_predecrement;
785       i6 -= c_predecrement;
786       i7 -= c_predecrement;
787       i8 -= c_predecrement;
788 
789       int32x4_t vaccX1_lo = vld1q_s32(w);
790       w = (void*)((uintptr_t)w + sizeof(int32x4_t));
791       int32x4_t vaccX1_hi = vld1q_s32(w);
792       w = (void*)((uintptr_t)w + sizeof(int32x4_t));
793 
794       const uint8x8_t vk0 = vld1_u8(w);
795       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
796       const uint8x8_t vi0 = vreinterpret_u8_u64(
797           vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
798       const int16x8_t vxk0 =
799           vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
800       const int16x8_t vxi0 =
801           vreinterpretq_s16_u16(sub_zero_point(vi0, va_zero_point));
802       int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
803       int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
804 
805       const uint8x8_t vk1 = vld1_u8(w);
806       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
807       const uint8x8_t vi1 = vreinterpret_u8_u64(
808           vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
809       const int16x8_t vxk1 =
810           vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
811       const int16x8_t vxi1 =
812           vreinterpretq_s16_u16(sub_zero_point(vi1, va_zero_point));
813       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
814       vaccX1_hi =
815           vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
816 
817       const uint8x8_t vk2 = vld1_u8(w);
818       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
819       const uint8x8_t vi2 = vreinterpret_u8_u64(
820           vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
821       const int16x8_t vxk2 =
822           vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
823       const int16x8_t vxi2 =
824           vreinterpretq_s16_u16(sub_zero_point(vi2, va_zero_point));
825       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
826       vaccX0_hi =
827           vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
828 
829       const uint8x8_t vk3 = vld1_u8(w);
830       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
831       const uint8x8_t vi3 = vreinterpret_u8_u64(
832           vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
833       const int16x8_t vxk3 =
834           vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
835       const int16x8_t vxi3 =
836           vreinterpretq_s16_u16(sub_zero_point(vi3, va_zero_point));
837       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
838       vaccX1_hi =
839           vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
840 
841       const uint8x8_t vk4 = vld1_u8(w);
842       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
843       const uint8x8_t vi4 = vreinterpret_u8_u64(
844           vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
845       const int16x8_t vxk4 =
846           vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
847       const int16x8_t vxi4 =
848           vreinterpretq_s16_u16(sub_zero_point(vi4, va_zero_point));
849       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
850       vaccX0_hi =
851           vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
852 
853       const uint8x8_t vk5 = vld1_u8(w);
854       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
855       const uint8x8_t vi5 = vreinterpret_u8_u64(
856           vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
857       const int16x8_t vxk5 =
858           vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
859       const int16x8_t vxi5 =
860           vreinterpretq_s16_u16(sub_zero_point(vi5, va_zero_point));
861       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
862       vaccX1_hi =
863           vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
864 
865       const uint8x8_t vk6 = vld1_u8(w);
866       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
867       const uint8x8_t vi6 = vreinterpret_u8_u64(
868           vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
869       const int16x8_t vxk6 =
870           vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
871       const int16x8_t vxi6 =
872           vreinterpretq_s16_u16(sub_zero_point(vi6, va_zero_point));
873       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
874       vaccX0_hi =
875           vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
876 
877       const uint8x8_t vk7 = vld1_u8(w);
878       w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
879       const uint8x8_t vi7 = vreinterpret_u8_u64(
880           vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
881       const int16x8_t vxk7 =
882           vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
883       const int16x8_t vxi7 =
884           vreinterpretq_s16_u16(sub_zero_point(vi7, va_zero_point));
885       vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
886       vaccX1_hi =
887           vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
888 
889       const uint8x8_t vk8 = vld1_u8(w);
890       const uint8x8_t vi8 = vreinterpret_u8_u64(
891           vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
892       const int16x8_t vxk8 =
893           vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
894       const int16x8_t vxi8 =
895           vreinterpretq_s16_u16(sub_zero_point(vi8, va_zero_point));
896       vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
897       vaccX0_hi =
898           vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
899 
900       int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
901       int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
902 
903       const float32x4_t vacc_lo_f =
904         vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
905       const float32x4_t vacc_hi_f =
906         vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
907 
908 #ifdef __aarch64__
909       vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
910       vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
911 
912       const int16x8_t vacc = vqaddq_s16(
913           vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point);
914 
915       uint8x8_t vout = vqmovun_s16(vacc);
916       vout = vmax_u8(vout, voutput_min);
917       vout = vmin_u8(vout, voutput_max);
918 #else
919       const float32x4_t vacc_lo_f_clamped =
920           vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
921       const float32x4_t vacc_hi_f_clamped =
922           vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
923       vacc_lo = vsubq_s32(
924           vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)), vimagic);
925       vacc_hi = vsubq_s32(
926           vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)), vimagic);
927       const int16x8_t vacc =
928           vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
929 
930       uint8x8_t vout = vqmovun_s16(vacc);
931 #endif
932 
933       if (c & 4) {
934         vst1_lane_u32(
935             __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0);
936         output += 4;
937         vout = vext_u8(vout, vout, 4);
938       }
939       if (c & 2) {
940         vst1_lane_u16(
941             __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0);
942         output += 2;
943         vout = vext_u8(vout, vout, 2);
944       }
945       if (c & 1) {
946         vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0);
947         output++;
948       }
949     }
950 
951     output = (uint8_t*)((uintptr_t)output + output_increment);
952   } while (--output_width != 0);
953 }
954