1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <arm_neon.h>
10 
11 #include <qnnpack/q8dwconv.h>
12 
13 /**
14  * Kernel for performing depthwise convolution with 27 weights per channel
15  * (e.g. a 3x3x3 kernel). The general strategy is:
16  * For each output row:
17  *   For each output pixel in that row:
18  *     For each of the three groups of 9 of the 27 weights (e.g. one yz slice of a 3x3x3 kernel):
19  *       For each group of 8 channels:
20  *          Load input pixel values from the indirection buffer and the weights,
21  *          multiply and accumulate them, and keep a running total of these
22  *          products and the bias in a temporary buffer
23  *     After the last group of 9 weights, requantize the totals to obtain the final output
24  *     Advance indirection buffer to next pixel in the row
25  *   Advance indirection buffer to next row
26  */
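/*
 * Note on buffers (editorial, inferred from the code below): outacc32 stages
 * per-channel int32 partial sums between the three passes over the packed
 * weights. Every pass stores full groups of 8 lanes, including the channel
 * remainder, so outacc32 is expected to hold at least `channels` rounded up
 * to a multiple of 8 int32 values. The packed weights appear to be laid out
 * in three sections matching the three passes: for each group of 8 channels,
 * 8 int32 bias values followed by taps 0-8; then, per group, taps 9-17; then
 * taps 18-26.
 */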
27 void pytorch_q8dwconv_ukernel_mp8x27__neon(
28     size_t channels,
29     size_t output_height,
30     size_t output_width,
31     const uint8_t** input,
32     const void* weights,
33     int32_t* outacc32,
34     uint8_t* output,
35     size_t input_row_stride, // Stride between rows in the indirection buffer
36     size_t input_col_stride, // Stride between columns in the indirection buffer
37     size_t
38         output_increment, // Padding between output pixels in the output buffer
39     const union pytorch_qnnp_conv_quantization_params
40         quantization_params[restrict static 1]) {
41   const uint8x8_t vinput_zero_point =
42       vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
43 #ifdef __aarch64__
44   const int16x8_t voutput_zero_point =
45       vld1q_dup_s16(&quantization_params->neon.output_zero_point);
46   const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min);
47   const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max);
48 #else
49   const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
50   const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
51   const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
52   const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
53 #endif
54 
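  /*
   * Each output pixel is computed in three passes over the packed weights,
   * 9 kernel taps per pass. The first two passes stage raw int32 sums in
   * outacc32; the third pass adds its contribution, requantizes, and writes
   * the uint8 output for the pixel.
   */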
55   for (size_t output_y = 0; output_y < output_height; output_y++) {
56     const uint8_t** input_row_start = input;
57     for (size_t output_x = 0; output_x < output_width; output_x++) {
58       uint8_t* output_start = output;
59       int32_t* outacc = outacc32;
60       const void* w = weights;
61       {
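        /*
         * Pass 1: taps 0-8. For each group of 8 channels, load the packed
         * int32 bias into the accumulators, multiply-accumulate the
         * zero-point-subtracted inputs and weights, and store the raw int32
         * sums to outacc32.
         */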
62         const uint8_t* i0 = input[0];
63         const uint8_t* i1 = input[1];
64         const uint8_t* i2 = input[2];
65         const uint8_t* i3 = input[3];
66         const uint8_t* i4 = input[4];
67         const uint8_t* i5 = input[5];
68         const uint8_t* i6 = input[6];
69         const uint8_t* i7 = input[7];
70         const uint8_t* i8 = input[8];
71 
72         size_t c = channels;
73         const uint8_t* kernel_zero_points_ptr =
74             quantization_params->neon.kernel_zero_points + channels;
75         for (; c >= 8; c -= 8) {
76           const uint8x8_t vkernel_zero_point =
77               vld1_u8(kernel_zero_points_ptr - c);
78           int32x4_t vaccX1_lo = vld1q_s32(w);
79           w = (void*)((uintptr_t)w + sizeof(int32x4_t));
80           int32x4_t vaccX1_hi = vld1q_s32(w);
81           w = (void*)((uintptr_t)w + sizeof(int32x4_t));
82 
83           const uint8x8_t vk0 = vld1_u8(w);
84           w += 8;
85           const uint8x8_t vi0 = vld1_u8(i0);
86           i0 += 8;
87           const int16x8_t vxk0 =
88               vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
89           const int16x8_t vxi0 =
90               vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
91           int32x4_t vaccX0_lo =
92               vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
93           int32x4_t vaccX0_hi =
94               vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
95 
96           const uint8x8_t vk1 = vld1_u8(w);
97           w += 8;
98           const uint8x8_t vi1 = vld1_u8(i1);
99           i1 += 8;
100           const int16x8_t vxk1 =
101               vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
102           const int16x8_t vxi1 =
103               vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
104           vaccX1_lo =
105               vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
106           vaccX1_hi =
107               vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
108 
109           const uint8x8_t vk2 = vld1_u8(w);
110           w += 8;
111           const uint8x8_t vi2 = vld1_u8(i2);
112           i2 += 8;
113           const int16x8_t vxk2 =
114               vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
115           const int16x8_t vxi2 =
116               vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
117           vaccX0_lo =
118               vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
119           vaccX0_hi =
120               vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
121 
122           const uint8x8_t vk3 = vld1_u8(w);
123           w += 8;
124           const uint8x8_t vi3 = vld1_u8(i3);
125           i3 += 8;
126           const int16x8_t vxk3 =
127               vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
128           const int16x8_t vxi3 =
129               vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
130           vaccX1_lo =
131               vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
132           vaccX1_hi =
133               vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
134 
135           const uint8x8_t vk4 = vld1_u8(w);
136           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
137           const uint8x8_t vi4 = vld1_u8(i4);
138           i4 += 8;
139           const int16x8_t vxk4 =
140               vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
141           const int16x8_t vxi4 =
142               vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
143           vaccX0_lo =
144               vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
145           vaccX0_hi =
146               vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
147 
148           const uint8x8_t vk5 = vld1_u8(w);
149           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
150           const uint8x8_t vi5 = vld1_u8(i5);
151           i5 += 8;
152           const int16x8_t vxk5 =
153               vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
154           const int16x8_t vxi5 =
155               vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
156           vaccX1_lo =
157               vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
158           vaccX1_hi =
159               vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
160 
161           const uint8x8_t vk6 = vld1_u8(w);
162           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
163           const uint8x8_t vi6 = vld1_u8(i6);
164           i6 += 8;
165           const int16x8_t vxk6 =
166               vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
167           const int16x8_t vxi6 =
168               vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
169           vaccX0_lo =
170               vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
171           vaccX0_hi =
172               vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
173 
174           const uint8x8_t vk7 = vld1_u8(w);
175           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
176           const uint8x8_t vi7 = vld1_u8(i7);
177           i7 += 8;
178           const int16x8_t vxk7 =
179               vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
180           const int16x8_t vxi7 =
181               vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
182           vaccX1_lo =
183               vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
184           vaccX1_hi =
185               vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
186 
187           const uint8x8_t vk8 = vld1_u8(w);
188           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
189           const uint8x8_t vi8 = vld1_u8(i8);
190           i8 += 8;
191           const int16x8_t vxk8 =
192               vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
193           const int16x8_t vxi8 =
194               vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
195           vaccX0_lo =
196               vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
197           vaccX0_hi =
198               vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
199 
200           int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
201           int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
202 
203           vst1q_s32(outacc, vacc_lo);
204           outacc += 4;
205           vst1q_s32(outacc, vacc_hi);
206           outacc += 4;
207         }
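        /*
         * Channel remainder (1-7 channels left): back the input pointers up
         * so the 8-byte loads end exactly at the last valid channel (staying
         * in bounds), then shift the loaded vector right so the valid bytes
         * land in the low lanes. The weights are still consumed as a full
         * 8-lane group.
         */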
208         if (c != 0) {
209           const size_t c_predecrement = 8 - c;
210           const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
211           i0 -= c_predecrement;
212           i1 -= c_predecrement;
213           i2 -= c_predecrement;
214           i3 -= c_predecrement;
215           i4 -= c_predecrement;
216           i5 -= c_predecrement;
217           i6 -= c_predecrement;
218           i7 -= c_predecrement;
219           i8 -= c_predecrement;
220 
221           const uint8x8_t vkernel_zero_point = vld1_u8(
222               &quantization_params->neon.kernel_zero_points[channels - c]);
223           int32x4_t vaccX1_lo = vld1q_s32(w);
224           w = (void*)((uintptr_t)w + sizeof(int32x4_t));
225           int32x4_t vaccX1_hi = vld1q_s32(w);
226           w = (void*)((uintptr_t)w + sizeof(int32x4_t));
227 
228           const uint8x8_t vk0 = vld1_u8(w);
229           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
230           const uint8x8_t vi0 = vreinterpret_u8_u64(
231               vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
232           const int16x8_t vxk0 =
233               vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
234           const int16x8_t vxi0 =
235               vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
236           int32x4_t vaccX0_lo =
237               vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
238           int32x4_t vaccX0_hi =
239               vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
240 
241           const uint8x8_t vk1 = vld1_u8(w);
242           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
243           const uint8x8_t vi1 = vreinterpret_u8_u64(
244               vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
245           const int16x8_t vxk1 =
246               vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
247           const int16x8_t vxi1 =
248               vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
249           vaccX1_lo =
250               vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
251           vaccX1_hi =
252               vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
253 
254           const uint8x8_t vk2 = vld1_u8(w);
255           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
256           const uint8x8_t vi2 = vreinterpret_u8_u64(
257               vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
258           const int16x8_t vxk2 =
259               vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
260           const int16x8_t vxi2 =
261               vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
262           vaccX0_lo =
263               vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
264           vaccX0_hi =
265               vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
266 
267           const uint8x8_t vk3 = vld1_u8(w);
268           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
269           const uint8x8_t vi3 = vreinterpret_u8_u64(
270               vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
271           const int16x8_t vxk3 =
272               vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
273           const int16x8_t vxi3 =
274               vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
275           vaccX1_lo =
276               vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
277           vaccX1_hi =
278               vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
279 
280           const uint8x8_t vk4 = vld1_u8(w);
281           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
282           const uint8x8_t vi4 = vreinterpret_u8_u64(
283               vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
284           const int16x8_t vxk4 =
285               vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
286           const int16x8_t vxi4 =
287               vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
288           vaccX0_lo =
289               vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
290           vaccX0_hi =
291               vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
292 
293           const uint8x8_t vk5 = vld1_u8(w);
294           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
295           const uint8x8_t vi5 = vreinterpret_u8_u64(
296               vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
297           const int16x8_t vxk5 =
298               vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
299           const int16x8_t vxi5 =
300               vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
301           vaccX1_lo =
302               vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
303           vaccX1_hi =
304               vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
305 
306           const uint8x8_t vk6 = vld1_u8(w);
307           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
308           const uint8x8_t vi6 = vreinterpret_u8_u64(
309               vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
310           const int16x8_t vxk6 =
311               vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
312           const int16x8_t vxi6 =
313               vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
314           vaccX0_lo =
315               vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
316           vaccX0_hi =
317               vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
318 
319           const uint8x8_t vk7 = vld1_u8(w);
320           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
321           const uint8x8_t vi7 = vreinterpret_u8_u64(
322               vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
323           const int16x8_t vxk7 =
324               vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
325           const int16x8_t vxi7 =
326               vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
327           vaccX1_lo =
328               vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
329           vaccX1_hi =
330               vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
331 
332           const uint8x8_t vk8 = vld1_u8(w);
333           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
334           const uint8x8_t vi8 = vreinterpret_u8_u64(
335               vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
336           const int16x8_t vxk8 =
337               vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
338           const int16x8_t vxi8 =
339               vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
340           vaccX0_lo =
341               vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
342           vaccX0_hi =
343               vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
344 
345           int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
346           int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
347 
348           vst1q_s32(outacc, vacc_lo);
349           outacc += 4;
350           vst1q_s32(outacc, vacc_hi);
351           outacc += 4;
352         }
353       }
354       {
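        /*
         * Pass 2: taps 9-17. Accumulate into fresh registers, then add the
         * partial sums staged in outacc32 by pass 1 and store them back.
         */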
355         const uint8_t* i0 = input[9];
356         const uint8_t* i1 = input[10];
357         const uint8_t* i2 = input[11];
358         const uint8_t* i3 = input[12];
359         const uint8_t* i4 = input[13];
360         const uint8_t* i5 = input[14];
361         const uint8_t* i6 = input[15];
362         const uint8_t* i7 = input[16];
363         const uint8_t* i8 = input[17];
364         output = output_start;
365         outacc = outacc32;
366 
367         size_t c = channels;
368         const uint8_t* kernel_zero_points_ptr =
369             quantization_params->neon.kernel_zero_points + channels;
370         for (; c >= 8; c -= 8) {
371           const uint8x8_t vkernel_zero_point =
372               vld1_u8(kernel_zero_points_ptr - c);
373           const uint8x8_t vk0 = vld1_u8(w);
374           w += 8;
375           const uint8x8_t vi0 = vld1_u8(i0);
376           i0 += 8;
377           const int16x8_t vxk0 =
378               vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
379           const int16x8_t vxi0 =
380               vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
381           int32x4_t vaccX0_lo =
382               vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
383           int32x4_t vaccX0_hi =
384               vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
385 
386           const uint8x8_t vk1 = vld1_u8(w);
387           w += 8;
388           const uint8x8_t vi1 = vld1_u8(i1);
389           i1 += 8;
390           const int16x8_t vxk1 =
391               vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
392           const int16x8_t vxi1 =
393               vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
394           int32x4_t vaccX1_lo =
395               vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
396           int32x4_t vaccX1_hi =
397               vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
398 
399           const uint8x8_t vk2 = vld1_u8(w);
400           w += 8;
401           const uint8x8_t vi2 = vld1_u8(i2);
402           i2 += 8;
403           const int16x8_t vxk2 =
404               vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
405           const int16x8_t vxi2 =
406               vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
407           vaccX0_lo =
408               vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
409           vaccX0_hi =
410               vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
411 
412           const uint8x8_t vk3 = vld1_u8(w);
413           w += 8;
414           const uint8x8_t vi3 = vld1_u8(i3);
415           i3 += 8;
416           const int16x8_t vxk3 =
417               vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
418           const int16x8_t vxi3 =
419               vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
420           vaccX1_lo =
421               vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
422           vaccX1_hi =
423               vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
424 
425           const uint8x8_t vk4 = vld1_u8(w);
426           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
427           const uint8x8_t vi4 = vld1_u8(i4);
428           i4 += 8;
429           const int16x8_t vxk4 =
430               vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
431           const int16x8_t vxi4 =
432               vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
433           vaccX0_lo =
434               vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
435           vaccX0_hi =
436               vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
437 
438           const uint8x8_t vk5 = vld1_u8(w);
439           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
440           const uint8x8_t vi5 = vld1_u8(i5);
441           i5 += 8;
442           const int16x8_t vxk5 =
443               vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
444           const int16x8_t vxi5 =
445               vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
446           vaccX1_lo =
447               vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
448           vaccX1_hi =
449               vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
450 
451           const uint8x8_t vk6 = vld1_u8(w);
452           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
453           const uint8x8_t vi6 = vld1_u8(i6);
454           i6 += 8;
455           const int16x8_t vxk6 =
456               vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
457           const int16x8_t vxi6 =
458               vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
459           vaccX0_lo =
460               vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
461           vaccX0_hi =
462               vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
463 
464           const uint8x8_t vk7 = vld1_u8(w);
465           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
466           const uint8x8_t vi7 = vld1_u8(i7);
467           i7 += 8;
468           const int16x8_t vxk7 =
469               vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
470           const int16x8_t vxi7 =
471               vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
472           vaccX1_lo =
473               vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
474           vaccX1_hi =
475               vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
476 
477           const uint8x8_t vk8 = vld1_u8(w);
478           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
479           const uint8x8_t vi8 = vld1_u8(i8);
480           i8 += 8;
481           const int16x8_t vxk8 =
482               vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
483           const int16x8_t vxi8 =
484               vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
485           vaccX0_lo =
486               vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
487           vaccX0_hi =
488               vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
489 
490           int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
491           int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
492 
493           const int32x4_t vacc_lo_old = vld1q_s32(outacc);
494           const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4);
495           vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
496           vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
497           vst1q_s32(outacc, vacc_lo);
498           outacc += 4;
499           vst1q_s32(outacc, vacc_hi);
500           outacc += 4;
501         }
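        /*
         * Pass-2 channel remainder: same backed-up, shifted loads as in
         * pass 1; lanes beyond the last channel are never written to the
         * output.
         */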
502         if (c != 0) {
503           const size_t c_predecrement = 8 - c;
504           const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
505           i0 -= c_predecrement;
506           i1 -= c_predecrement;
507           i2 -= c_predecrement;
508           i3 -= c_predecrement;
509           i4 -= c_predecrement;
510           i5 -= c_predecrement;
511           i6 -= c_predecrement;
512           i7 -= c_predecrement;
513           i8 -= c_predecrement;
514 
515           const uint8x8_t vkernel_zero_point = vld1_u8(
516               &quantization_params->neon.kernel_zero_points[channels - c]);
517           const uint8x8_t vk0 = vld1_u8(w);
518           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
519           const uint8x8_t vi0 = vreinterpret_u8_u64(
520               vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
521           const int16x8_t vxk0 =
522               vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
523           const int16x8_t vxi0 =
524               vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
525           int32x4_t vaccX0_lo =
526               vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
527           int32x4_t vaccX0_hi =
528               vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
529 
530           const uint8x8_t vk1 = vld1_u8(w);
531           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
532           const uint8x8_t vi1 = vreinterpret_u8_u64(
533               vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
534           const int16x8_t vxk1 =
535               vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
536           const int16x8_t vxi1 =
537               vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
538           int32x4_t vaccX1_lo =
539               vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
540           int32x4_t vaccX1_hi =
541               vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
542 
543           const uint8x8_t vk2 = vld1_u8(w);
544           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
545           const uint8x8_t vi2 = vreinterpret_u8_u64(
546               vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
547           const int16x8_t vxk2 =
548               vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
549           const int16x8_t vxi2 =
550               vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
551           vaccX0_lo =
552               vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
553           vaccX0_hi =
554               vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
555 
556           const uint8x8_t vk3 = vld1_u8(w);
557           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
558           const uint8x8_t vi3 = vreinterpret_u8_u64(
559               vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
560           const int16x8_t vxk3 =
561               vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
562           const int16x8_t vxi3 =
563               vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
564           vaccX1_lo =
565               vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
566           vaccX1_hi =
567               vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
568 
569           const uint8x8_t vk4 = vld1_u8(w);
570           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
571           const uint8x8_t vi4 = vreinterpret_u8_u64(
572               vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
573           const int16x8_t vxk4 =
574               vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
575           const int16x8_t vxi4 =
576               vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
577           vaccX0_lo =
578               vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
579           vaccX0_hi =
580               vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
581 
582           const uint8x8_t vk5 = vld1_u8(w);
583           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
584           const uint8x8_t vi5 = vreinterpret_u8_u64(
585               vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
586           const int16x8_t vxk5 =
587               vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
588           const int16x8_t vxi5 =
589               vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
590           vaccX1_lo =
591               vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
592           vaccX1_hi =
593               vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
594 
595           const uint8x8_t vk6 = vld1_u8(w);
596           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
597           const uint8x8_t vi6 = vreinterpret_u8_u64(
598               vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
599           const int16x8_t vxk6 =
600               vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
601           const int16x8_t vxi6 =
602               vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
603           vaccX0_lo =
604               vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
605           vaccX0_hi =
606               vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
607 
608           const uint8x8_t vk7 = vld1_u8(w);
609           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
610           const uint8x8_t vi7 = vreinterpret_u8_u64(
611               vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
612           const int16x8_t vxk7 =
613               vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
614           const int16x8_t vxi7 =
615               vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
616           vaccX1_lo =
617               vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
618           vaccX1_hi =
619               vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
620 
621           const uint8x8_t vk8 = vld1_u8(w);
622           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
623           const uint8x8_t vi8 = vreinterpret_u8_u64(
624               vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
625           const int16x8_t vxk8 =
626               vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
627           const int16x8_t vxi8 =
628               vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
629           vaccX0_lo =
630               vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
631           vaccX0_hi =
632               vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
633 
634           int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
635           int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
636 
637           const int32x4_t vacc_lo_old = vld1q_s32(outacc);
638           const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4);
639           vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
640           vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
641           vst1q_s32(outacc, vacc_lo);
642           outacc += 4;
643           vst1q_s32(outacc, vacc_hi);
644           outacc += 4;
645         }
646       }
647 
648       {
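        /*
         * Pass 3: taps 18-26. Add the final contribution to the staged sums,
         * then requantize (scale by the per-channel requantization scale,
         * round, add the output zero point, clamp) and store the uint8
         * results. The indirection buffer is also advanced to the next
         * output pixel here.
         */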
649         const uint8_t* i0 = input[18];
650         const uint8_t* i1 = input[19];
651         const uint8_t* i2 = input[20];
652         const uint8_t* i3 = input[21];
653         const uint8_t* i4 = input[22];
654         const uint8_t* i5 = input[23];
655         const uint8_t* i6 = input[24];
656         const uint8_t* i7 = input[25];
657         const uint8_t* i8 = input[26];
658         input = (const uint8_t**)((uintptr_t)input + input_col_stride);
659         output = output_start;
660         outacc = outacc32;
661 
662         size_t c = channels;
663         const uint8_t* kernel_zero_points_ptr =
664             quantization_params->neon.kernel_zero_points + channels;
665         const float* requantization_scales_ptr =
666             quantization_params->neon.requantization_scales + channels;
667         for (; c >= 8; c -= 8) {
668           const uint8x8_t vkernel_zero_point =
669               vld1_u8(kernel_zero_points_ptr - c);
670           const uint8x8_t vk0 = vld1_u8(w);
671           w += 8;
672           const uint8x8_t vi0 = vld1_u8(i0);
673           i0 += 8;
674           const int16x8_t vxk0 =
675               vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
676           const int16x8_t vxi0 =
677               vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
678           int32x4_t vaccX0_lo =
679               vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
680           int32x4_t vaccX0_hi =
681               vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
682 
683           const uint8x8_t vk1 = vld1_u8(w);
684           w += 8;
685           const uint8x8_t vi1 = vld1_u8(i1);
686           i1 += 8;
687           const int16x8_t vxk1 =
688               vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
689           const int16x8_t vxi1 =
690               vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
691           int32x4_t vaccX1_lo =
692               vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
693           int32x4_t vaccX1_hi =
694               vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
695 
696           const uint8x8_t vk2 = vld1_u8(w);
697           w += 8;
698           const uint8x8_t vi2 = vld1_u8(i2);
699           i2 += 8;
700           const int16x8_t vxk2 =
701               vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
702           const int16x8_t vxi2 =
703               vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
704           vaccX0_lo =
705               vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
706           vaccX0_hi =
707               vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
708 
709           const uint8x8_t vk3 = vld1_u8(w);
710           w += 8;
711           const uint8x8_t vi3 = vld1_u8(i3);
712           i3 += 8;
713           const int16x8_t vxk3 =
714               vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
715           const int16x8_t vxi3 =
716               vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
717           vaccX1_lo =
718               vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
719           vaccX1_hi =
720               vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
721 
722           const uint8x8_t vk4 = vld1_u8(w);
723           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
724           const uint8x8_t vi4 = vld1_u8(i4);
725           i4 += 8;
726           const int16x8_t vxk4 =
727               vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
728           const int16x8_t vxi4 =
729               vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
730           vaccX0_lo =
731               vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
732           vaccX0_hi =
733               vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
734 
735           const uint8x8_t vk5 = vld1_u8(w);
736           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
737           const uint8x8_t vi5 = vld1_u8(i5);
738           i5 += 8;
739           const int16x8_t vxk5 =
740               vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
741           const int16x8_t vxi5 =
742               vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
743           vaccX1_lo =
744               vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
745           vaccX1_hi =
746               vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
747 
748           const uint8x8_t vk6 = vld1_u8(w);
749           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
750           const uint8x8_t vi6 = vld1_u8(i6);
751           i6 += 8;
752           const int16x8_t vxk6 =
753               vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
754           const int16x8_t vxi6 =
755               vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
756           vaccX0_lo =
757               vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
758           vaccX0_hi =
759               vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
760 
761           const uint8x8_t vk7 = vld1_u8(w);
762           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
763           const uint8x8_t vi7 = vld1_u8(i7);
764           i7 += 8;
765           const int16x8_t vxk7 =
766               vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
767           const int16x8_t vxi7 =
768               vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
769           vaccX1_lo =
770               vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
771           vaccX1_hi =
772               vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
773 
774           const uint8x8_t vk8 = vld1_u8(w);
775           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
776           const uint8x8_t vi8 = vld1_u8(i8);
777           i8 += 8;
778           const int16x8_t vxk8 =
779               vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
780           const int16x8_t vxi8 =
781               vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
782           vaccX0_lo =
783               vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
784           vaccX0_hi =
785               vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
786 
787           int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
788           int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
789 
790           const int32x4_t vacc_lo_old = vld1q_s32(outacc);
791           outacc += 4;
792           const int32x4_t vacc_hi_old = vld1q_s32(outacc);
793           outacc += 4;
794           vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
795           vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
796 
797           const float32x4_t requantization_scale_v_lo =
798               vld1q_f32(requantization_scales_ptr - c);
799           const float32x4_t requantization_scale_v_hi =
800               vld1q_f32(requantization_scales_ptr - c + 4);
801 
802           const float32x4_t vacc_lo_f =
803               vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
804           const float32x4_t vacc_hi_f =
805               vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
806 
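          /*
           * Requantization: on AArch64, round the scaled floats to the
           * nearest int32, narrow with saturation, add the output zero
           * point, and clamp to [output_min, output_max]. On 32-bit ARM, the
           * same result is produced with the float "magic number" trick:
           * clamp in float, add vfmagic so the integer value can be read out
           * of the float bit pattern, and subtract vimagic (which also
           * appears to fold in the output zero point).
           */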
807 #ifdef __aarch64__
808           vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
809           vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
810 
811           const int16x8_t vacc = vqaddq_s16(
812               vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi),
813               voutput_zero_point);
814 
815           uint8x8_t vout = vqmovun_s16(vacc);
816           vout = vmax_u8(vout, voutput_min);
817           vout = vmin_u8(vout, voutput_max);
818 #else
819           const float32x4_t vacc_lo_f_clamped =
820               vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
821           const float32x4_t vacc_hi_f_clamped =
822               vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
823           vacc_lo = vsubq_s32(
824               vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)),
825               vimagic);
826           vacc_hi = vsubq_s32(
827               vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)),
828               vimagic);
829           const int16x8_t vacc =
830               vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
831 
832           uint8x8_t vout = vqmovun_s16(vacc);
833 #endif
834 
835           vst1_u8(output, vout);
836           output += 8;
837         }
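        /*
         * Pass-3 channel remainder: shifted in-bounds loads as above, then
         * requantize and store only the final 1-7 output bytes.
         */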
838         if (c != 0) {
839           const size_t c_predecrement = 8 - c;
840           const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
841           i0 -= c_predecrement;
842           i1 -= c_predecrement;
843           i2 -= c_predecrement;
844           i3 -= c_predecrement;
845           i4 -= c_predecrement;
846           i5 -= c_predecrement;
847           i6 -= c_predecrement;
848           i7 -= c_predecrement;
849           i8 -= c_predecrement;
850 
851           const uint8x8_t vkernel_zero_point = vld1_u8(
852               &quantization_params->neon.kernel_zero_points[channels - c]);
853           const uint8x8_t vk0 = vld1_u8(w);
854           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
855           const uint8x8_t vi0 = vreinterpret_u8_u64(
856               vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
857           const int16x8_t vxk0 =
858               vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
859           const int16x8_t vxi0 =
860               vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
861           int32x4_t vaccX0_lo =
862               vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
863           int32x4_t vaccX0_hi =
864               vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
865 
866           const uint8x8_t vk1 = vld1_u8(w);
867           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
868           const uint8x8_t vi1 = vreinterpret_u8_u64(
869               vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
870           const int16x8_t vxk1 =
871               vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
872           const int16x8_t vxi1 =
873               vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
874           int32x4_t vaccX1_lo =
875               vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
876           int32x4_t vaccX1_hi =
877               vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
878 
879           const uint8x8_t vk2 = vld1_u8(w);
880           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
881           const uint8x8_t vi2 = vreinterpret_u8_u64(
882               vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
883           const int16x8_t vxk2 =
884               vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
885           const int16x8_t vxi2 =
886               vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
887           vaccX0_lo =
888               vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
889           vaccX0_hi =
890               vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
891 
892           const uint8x8_t vk3 = vld1_u8(w);
893           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
894           const uint8x8_t vi3 = vreinterpret_u8_u64(
895               vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
896           const int16x8_t vxk3 =
897               vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
898           const int16x8_t vxi3 =
899               vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
900           vaccX1_lo =
901               vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
902           vaccX1_hi =
903               vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
904 
905           const uint8x8_t vk4 = vld1_u8(w);
906           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
907           const uint8x8_t vi4 = vreinterpret_u8_u64(
908               vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
909           const int16x8_t vxk4 =
910               vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
911           const int16x8_t vxi4 =
912               vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
913           vaccX0_lo =
914               vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
915           vaccX0_hi =
916               vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
917 
918           const uint8x8_t vk5 = vld1_u8(w);
919           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
920           const uint8x8_t vi5 = vreinterpret_u8_u64(
921               vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
922           const int16x8_t vxk5 =
923               vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
924           const int16x8_t vxi5 =
925               vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
926           vaccX1_lo =
927               vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
928           vaccX1_hi =
929               vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
930 
931           const uint8x8_t vk6 = vld1_u8(w);
932           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
933           const uint8x8_t vi6 = vreinterpret_u8_u64(
934               vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
935           const int16x8_t vxk6 =
936               vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
937           const int16x8_t vxi6 =
938               vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
939           vaccX0_lo =
940               vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
941           vaccX0_hi =
942               vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
943 
944           const uint8x8_t vk7 = vld1_u8(w);
945           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
946           const uint8x8_t vi7 = vreinterpret_u8_u64(
947               vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
948           const int16x8_t vxk7 =
949               vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
950           const int16x8_t vxi7 =
951               vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
952           vaccX1_lo =
953               vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
954           vaccX1_hi =
955               vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
956 
957           const uint8x8_t vk8 = vld1_u8(w);
958           w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
959           const uint8x8_t vi8 = vreinterpret_u8_u64(
960               vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
961           const int16x8_t vxk8 =
962               vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
963           const int16x8_t vxi8 =
964               vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
965           vaccX0_lo =
966               vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
967           vaccX0_hi =
968               vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
969 
970           int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
971           int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
972 
973           const int32x4_t vacc_lo_old = vld1q_s32(outacc);
974           const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4);
975           vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
976           vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
977 
978           const float32x4_t requantization_scale_v_lo = vld1q_f32(
979               &quantization_params->neon.requantization_scales[channels - c]);
980           const float32x4_t requantization_scale_v_hi =
981               vld1q_f32(&quantization_params->neon
982                              .requantization_scales[channels - c + 4]);
983 
984           const float32x4_t vacc_lo_f =
985               vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
986           const float32x4_t vacc_hi_f =
987               vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
988 
989 #ifdef __aarch64__
990           vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
991           vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
992 
993           const int16x8_t vacc = vqaddq_s16(
994               vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi),
995               voutput_zero_point);
996 
997           uint8x8_t vout = vqmovun_s16(vacc);
998           vout = vmax_u8(vout, voutput_min);
999           vout = vmin_u8(vout, voutput_max);
1000 #else
1001           const float32x4_t vacc_lo_f_clamped =
1002               vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
1003           const float32x4_t vacc_hi_f_clamped =
1004               vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
1005           vacc_lo = vsubq_s32(
1006               vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)),
1007               vimagic);
1008           vacc_hi = vsubq_s32(
1009               vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)),
1010               vimagic);
1011           const int16x8_t vacc =
1012               vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
1013 
1014           uint8x8_t vout = vqmovun_s16(vacc);
1015 #endif
1016 
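          /*
           * Store the remaining 1-7 output bytes: write 4, 2, then 1 byte(s)
           * depending on the bits of c, rotating the vector with vext after
           * each partial store.
           */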
1017           if (c & 4) {
1018             vst1_lane_u32(
1019                 __builtin_assume_aligned(output, 1),
1020                 vreinterpret_u32_u8(vout),
1021                 0);
1022             output += 4;
1023             vout = vext_u8(vout, vout, 4);
1024           }
1025           if (c & 2) {
1026             vst1_lane_u16(
1027                 __builtin_assume_aligned(output, 1),
1028                 vreinterpret_u16_u8(vout),
1029                 0);
1030             output += 2;
1031             vout = vext_u8(vout, vout, 2);
1032           }
1033           if (c & 1) {
1034             vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0);
1035             output++;
1036           }
1037         }
1038       }
1039 
1040       output = (uint8_t*)((uintptr_t)output + output_increment);
1041     }
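    /*
     * Advance the indirection buffer to the next output row, relative to the
     * start of the current row rather than the last pixel processed.
     */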
1042     input = (const uint8_t**)((uintptr_t)input_row_start + input_row_stride);
1043   }
1044 }
1045