1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <arm_neon.h>
10
11 #include <qnnpack/q8dwconv.h>
12
13 /**
14 * Kernel for performing depthwise convolution with a kernel of weights of
15 * size 27 (ex. 3x3x3). The general strategy is
16 * For each output row:
17 * For each output pixel in that row:
18 * For each of the three groupings of 9 of the 27 weights (ex. one yz slice for a 3x3x3 kernel):
19 * For each grouping of 8 channels:
20 * Load input pixel values from the indirection buffer and the weights,
21 * multiply and add them, and keep track of a running total of these
22 * products and the bias in a temporary buffer
23 * Perform requantization to obtain final output
24 * Advance indirection buffer to next pixel in the row
25 * Advance indirection buffer to next row
26 */
pytorch_q8dwconv_ukernel_mp8x27__neon(size_t channels,size_t output_height,size_t output_width,const uint8_t ** input,const void * weights,int32_t * outacc32,uint8_t * output,size_t input_row_stride,size_t input_col_stride,size_t output_increment,const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static1])27 void pytorch_q8dwconv_ukernel_mp8x27__neon(
28 size_t channels,
29 size_t output_height,
30 size_t output_width,
31 const uint8_t** input,
32 const void* weights,
33 int32_t* outacc32,
34 uint8_t* output,
35 size_t input_row_stride, // Stride between rows in the indirection buffer
36 size_t input_col_stride, // Stride between columns in the indirection buffer
37 size_t
38 output_increment, // Padding between output pixels in the output buffer
39 const union pytorch_qnnp_conv_quantization_params
40 quantization_params[restrict static 1]) {
41 const uint8x8_t vinput_zero_point =
42 vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
43 #ifdef __aarch64__
44 const int16x8_t voutput_zero_point =
45 vld1q_dup_s16(&quantization_params->neon.output_zero_point);
46 const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min);
47 const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max);
48 #else
49 const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
50 const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
51 const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
52 const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
53 #endif
54
55 for (size_t output_y = 0; output_y < output_height; output_y++) {
56 const uint8_t** input_row_start = input;
57 for (size_t output_x = 0; output_x < output_width; output_x++) {
58 uint8_t* output_start = output;
59 int32_t* outacc = outacc32;
60 const void* w = weights;
61 {
62 const uint8_t* i0 = input[0];
63 const uint8_t* i1 = input[1];
64 const uint8_t* i2 = input[2];
65 const uint8_t* i3 = input[3];
66 const uint8_t* i4 = input[4];
67 const uint8_t* i5 = input[5];
68 const uint8_t* i6 = input[6];
69 const uint8_t* i7 = input[7];
70 const uint8_t* i8 = input[8];
71
72 size_t c = channels;
73 const uint8_t* kernel_zero_points_ptr =
74 quantization_params->neon.kernel_zero_points + channels;
75 for (; c >= 8; c -= 8) {
76 const uint8x8_t vkernel_zero_point =
77 vld1_u8(kernel_zero_points_ptr - c);
78 int32x4_t vaccX1_lo = vld1q_s32(w);
79 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
80 int32x4_t vaccX1_hi = vld1q_s32(w);
81 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
82
83 const uint8x8_t vk0 = vld1_u8(w);
84 w += 8;
85 const uint8x8_t vi0 = vld1_u8(i0);
86 i0 += 8;
87 const int16x8_t vxk0 =
88 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
89 const int16x8_t vxi0 =
90 vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
91 int32x4_t vaccX0_lo =
92 vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
93 int32x4_t vaccX0_hi =
94 vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
95
96 const uint8x8_t vk1 = vld1_u8(w);
97 w += 8;
98 const uint8x8_t vi1 = vld1_u8(i1);
99 i1 += 8;
100 const int16x8_t vxk1 =
101 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
102 const int16x8_t vxi1 =
103 vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
104 vaccX1_lo =
105 vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
106 vaccX1_hi =
107 vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
108
109 const uint8x8_t vk2 = vld1_u8(w);
110 w += 8;
111 const uint8x8_t vi2 = vld1_u8(i2);
112 i2 += 8;
113 const int16x8_t vxk2 =
114 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
115 const int16x8_t vxi2 =
116 vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
117 vaccX0_lo =
118 vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
119 vaccX0_hi =
120 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
121
122 const uint8x8_t vk3 = vld1_u8(w);
123 w += 8;
124 const uint8x8_t vi3 = vld1_u8(i3);
125 i3 += 8;
126 const int16x8_t vxk3 =
127 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
128 const int16x8_t vxi3 =
129 vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
130 vaccX1_lo =
131 vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
132 vaccX1_hi =
133 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
134
135 const uint8x8_t vk4 = vld1_u8(w);
136 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
137 const uint8x8_t vi4 = vld1_u8(i4);
138 i4 += 8;
139 const int16x8_t vxk4 =
140 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
141 const int16x8_t vxi4 =
142 vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
143 vaccX0_lo =
144 vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
145 vaccX0_hi =
146 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
147
148 const uint8x8_t vk5 = vld1_u8(w);
149 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
150 const uint8x8_t vi5 = vld1_u8(i5);
151 i5 += 8;
152 const int16x8_t vxk5 =
153 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
154 const int16x8_t vxi5 =
155 vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
156 vaccX1_lo =
157 vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
158 vaccX1_hi =
159 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
160
161 const uint8x8_t vk6 = vld1_u8(w);
162 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
163 const uint8x8_t vi6 = vld1_u8(i6);
164 i6 += 8;
165 const int16x8_t vxk6 =
166 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
167 const int16x8_t vxi6 =
168 vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
169 vaccX0_lo =
170 vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
171 vaccX0_hi =
172 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
173
174 const uint8x8_t vk7 = vld1_u8(w);
175 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
176 const uint8x8_t vi7 = vld1_u8(i7);
177 i7 += 8;
178 const int16x8_t vxk7 =
179 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
180 const int16x8_t vxi7 =
181 vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
182 vaccX1_lo =
183 vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
184 vaccX1_hi =
185 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
186
187 const uint8x8_t vk8 = vld1_u8(w);
188 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
189 const uint8x8_t vi8 = vld1_u8(i8);
190 i8 += 8;
191 const int16x8_t vxk8 =
192 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
193 const int16x8_t vxi8 =
194 vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
195 vaccX0_lo =
196 vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
197 vaccX0_hi =
198 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
199
200 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
201 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
202
203 vst1q_s32(outacc, vacc_lo);
204 outacc += 4;
205 vst1q_s32(outacc, vacc_hi);
206 outacc += 4;
207 }
208 if (c != 0) {
209 const size_t c_predecrement = 8 - c;
210 const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
211 i0 -= c_predecrement;
212 i1 -= c_predecrement;
213 i2 -= c_predecrement;
214 i3 -= c_predecrement;
215 i4 -= c_predecrement;
216 i5 -= c_predecrement;
217 i6 -= c_predecrement;
218 i7 -= c_predecrement;
219 i8 -= c_predecrement;
220
221 const uint8x8_t vkernel_zero_point = vld1_u8(
222 &quantization_params->neon.kernel_zero_points[channels - c]);
223 int32x4_t vaccX1_lo = vld1q_s32(w);
224 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
225 int32x4_t vaccX1_hi = vld1q_s32(w);
226 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
227
228 const uint8x8_t vk0 = vld1_u8(w);
229 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
230 const uint8x8_t vi0 = vreinterpret_u8_u64(
231 vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
232 const int16x8_t vxk0 =
233 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
234 const int16x8_t vxi0 =
235 vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
236 int32x4_t vaccX0_lo =
237 vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
238 int32x4_t vaccX0_hi =
239 vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
240
241 const uint8x8_t vk1 = vld1_u8(w);
242 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
243 const uint8x8_t vi1 = vreinterpret_u8_u64(
244 vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
245 const int16x8_t vxk1 =
246 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
247 const int16x8_t vxi1 =
248 vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
249 vaccX1_lo =
250 vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1));
251 vaccX1_hi =
252 vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1));
253
254 const uint8x8_t vk2 = vld1_u8(w);
255 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
256 const uint8x8_t vi2 = vreinterpret_u8_u64(
257 vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
258 const int16x8_t vxk2 =
259 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
260 const int16x8_t vxi2 =
261 vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
262 vaccX0_lo =
263 vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
264 vaccX0_hi =
265 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
266
267 const uint8x8_t vk3 = vld1_u8(w);
268 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
269 const uint8x8_t vi3 = vreinterpret_u8_u64(
270 vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
271 const int16x8_t vxk3 =
272 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
273 const int16x8_t vxi3 =
274 vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
275 vaccX1_lo =
276 vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
277 vaccX1_hi =
278 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
279
280 const uint8x8_t vk4 = vld1_u8(w);
281 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
282 const uint8x8_t vi4 = vreinterpret_u8_u64(
283 vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
284 const int16x8_t vxk4 =
285 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
286 const int16x8_t vxi4 =
287 vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
288 vaccX0_lo =
289 vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
290 vaccX0_hi =
291 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
292
293 const uint8x8_t vk5 = vld1_u8(w);
294 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
295 const uint8x8_t vi5 = vreinterpret_u8_u64(
296 vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
297 const int16x8_t vxk5 =
298 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
299 const int16x8_t vxi5 =
300 vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
301 vaccX1_lo =
302 vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
303 vaccX1_hi =
304 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
305
306 const uint8x8_t vk6 = vld1_u8(w);
307 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
308 const uint8x8_t vi6 = vreinterpret_u8_u64(
309 vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
310 const int16x8_t vxk6 =
311 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
312 const int16x8_t vxi6 =
313 vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
314 vaccX0_lo =
315 vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
316 vaccX0_hi =
317 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
318
319 const uint8x8_t vk7 = vld1_u8(w);
320 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
321 const uint8x8_t vi7 = vreinterpret_u8_u64(
322 vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
323 const int16x8_t vxk7 =
324 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
325 const int16x8_t vxi7 =
326 vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
327 vaccX1_lo =
328 vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
329 vaccX1_hi =
330 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
331
332 const uint8x8_t vk8 = vld1_u8(w);
333 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
334 const uint8x8_t vi8 = vreinterpret_u8_u64(
335 vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
336 const int16x8_t vxk8 =
337 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
338 const int16x8_t vxi8 =
339 vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
340 vaccX0_lo =
341 vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
342 vaccX0_hi =
343 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
344
345 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
346 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
347
348 vst1q_s32(outacc, vacc_lo);
349 outacc += 4;
350 vst1q_s32(outacc, vacc_hi);
351 outacc += 4;
352 }
353 }
354 {
355 const uint8_t* i0 = input[9];
356 const uint8_t* i1 = input[10];
357 const uint8_t* i2 = input[11];
358 const uint8_t* i3 = input[12];
359 const uint8_t* i4 = input[13];
360 const uint8_t* i5 = input[14];
361 const uint8_t* i6 = input[15];
362 const uint8_t* i7 = input[16];
363 const uint8_t* i8 = input[17];
364 output = output_start;
365 outacc = outacc32;
366
367 size_t c = channels;
368 const uint8_t* kernel_zero_points_ptr =
369 quantization_params->neon.kernel_zero_points + channels;
370 for (; c >= 8; c -= 8) {
371 const uint8x8_t vkernel_zero_point =
372 vld1_u8(kernel_zero_points_ptr - c);
373 const uint8x8_t vk0 = vld1_u8(w);
374 w += 8;
375 const uint8x8_t vi0 = vld1_u8(i0);
376 i0 += 8;
377 const int16x8_t vxk0 =
378 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
379 const int16x8_t vxi0 =
380 vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
381 int32x4_t vaccX0_lo =
382 vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
383 int32x4_t vaccX0_hi =
384 vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
385
386 const uint8x8_t vk1 = vld1_u8(w);
387 w += 8;
388 const uint8x8_t vi1 = vld1_u8(i1);
389 i1 += 8;
390 const int16x8_t vxk1 =
391 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
392 const int16x8_t vxi1 =
393 vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
394 int32x4_t vaccX1_lo =
395 vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
396 int32x4_t vaccX1_hi =
397 vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
398
399 const uint8x8_t vk2 = vld1_u8(w);
400 w += 8;
401 const uint8x8_t vi2 = vld1_u8(i2);
402 i2 += 8;
403 const int16x8_t vxk2 =
404 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
405 const int16x8_t vxi2 =
406 vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
407 vaccX0_lo =
408 vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
409 vaccX0_hi =
410 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
411
412 const uint8x8_t vk3 = vld1_u8(w);
413 w += 8;
414 const uint8x8_t vi3 = vld1_u8(i3);
415 i3 += 8;
416 const int16x8_t vxk3 =
417 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
418 const int16x8_t vxi3 =
419 vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
420 vaccX1_lo =
421 vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
422 vaccX1_hi =
423 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
424
425 const uint8x8_t vk4 = vld1_u8(w);
426 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
427 const uint8x8_t vi4 = vld1_u8(i4);
428 i4 += 8;
429 const int16x8_t vxk4 =
430 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
431 const int16x8_t vxi4 =
432 vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
433 vaccX0_lo =
434 vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
435 vaccX0_hi =
436 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
437
438 const uint8x8_t vk5 = vld1_u8(w);
439 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
440 const uint8x8_t vi5 = vld1_u8(i5);
441 i5 += 8;
442 const int16x8_t vxk5 =
443 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
444 const int16x8_t vxi5 =
445 vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
446 vaccX1_lo =
447 vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
448 vaccX1_hi =
449 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
450
451 const uint8x8_t vk6 = vld1_u8(w);
452 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
453 const uint8x8_t vi6 = vld1_u8(i6);
454 i6 += 8;
455 const int16x8_t vxk6 =
456 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
457 const int16x8_t vxi6 =
458 vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
459 vaccX0_lo =
460 vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
461 vaccX0_hi =
462 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
463
464 const uint8x8_t vk7 = vld1_u8(w);
465 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
466 const uint8x8_t vi7 = vld1_u8(i7);
467 i7 += 8;
468 const int16x8_t vxk7 =
469 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
470 const int16x8_t vxi7 =
471 vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
472 vaccX1_lo =
473 vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
474 vaccX1_hi =
475 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
476
477 const uint8x8_t vk8 = vld1_u8(w);
478 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
479 const uint8x8_t vi8 = vld1_u8(i8);
480 i8 += 8;
481 const int16x8_t vxk8 =
482 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
483 const int16x8_t vxi8 =
484 vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
485 vaccX0_lo =
486 vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
487 vaccX0_hi =
488 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
489
490 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
491 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
492
493 const int32x4_t vacc_lo_old = vld1q_s32(outacc);
494 const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4);
495 vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
496 vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
497 vst1q_s32(outacc, vacc_lo);
498 outacc += 4;
499 vst1q_s32(outacc, vacc_hi);
500 outacc += 4;
501 }
502 if (c != 0) {
503 const size_t c_predecrement = 8 - c;
504 const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
505 i0 -= c_predecrement;
506 i1 -= c_predecrement;
507 i2 -= c_predecrement;
508 i3 -= c_predecrement;
509 i4 -= c_predecrement;
510 i5 -= c_predecrement;
511 i6 -= c_predecrement;
512 i7 -= c_predecrement;
513 i8 -= c_predecrement;
514
515 const uint8x8_t vkernel_zero_point = vld1_u8(
516 &quantization_params->neon.kernel_zero_points[channels - c]);
517 const uint8x8_t vk0 = vld1_u8(w);
518 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
519 const uint8x8_t vi0 = vreinterpret_u8_u64(
520 vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
521 const int16x8_t vxk0 =
522 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
523 const int16x8_t vxi0 =
524 vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
525 int32x4_t vaccX0_lo =
526 vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
527 int32x4_t vaccX0_hi =
528 vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
529
530 const uint8x8_t vk1 = vld1_u8(w);
531 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
532 const uint8x8_t vi1 = vreinterpret_u8_u64(
533 vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
534 const int16x8_t vxk1 =
535 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
536 const int16x8_t vxi1 =
537 vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
538 int32x4_t vaccX1_lo =
539 vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
540 int32x4_t vaccX1_hi =
541 vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
542
543 const uint8x8_t vk2 = vld1_u8(w);
544 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
545 const uint8x8_t vi2 = vreinterpret_u8_u64(
546 vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
547 const int16x8_t vxk2 =
548 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
549 const int16x8_t vxi2 =
550 vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
551 vaccX0_lo =
552 vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
553 vaccX0_hi =
554 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
555
556 const uint8x8_t vk3 = vld1_u8(w);
557 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
558 const uint8x8_t vi3 = vreinterpret_u8_u64(
559 vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
560 const int16x8_t vxk3 =
561 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
562 const int16x8_t vxi3 =
563 vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
564 vaccX1_lo =
565 vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
566 vaccX1_hi =
567 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
568
569 const uint8x8_t vk4 = vld1_u8(w);
570 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
571 const uint8x8_t vi4 = vreinterpret_u8_u64(
572 vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
573 const int16x8_t vxk4 =
574 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
575 const int16x8_t vxi4 =
576 vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
577 vaccX0_lo =
578 vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
579 vaccX0_hi =
580 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
581
582 const uint8x8_t vk5 = vld1_u8(w);
583 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
584 const uint8x8_t vi5 = vreinterpret_u8_u64(
585 vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
586 const int16x8_t vxk5 =
587 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
588 const int16x8_t vxi5 =
589 vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
590 vaccX1_lo =
591 vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
592 vaccX1_hi =
593 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
594
595 const uint8x8_t vk6 = vld1_u8(w);
596 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
597 const uint8x8_t vi6 = vreinterpret_u8_u64(
598 vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
599 const int16x8_t vxk6 =
600 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
601 const int16x8_t vxi6 =
602 vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
603 vaccX0_lo =
604 vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
605 vaccX0_hi =
606 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
607
608 const uint8x8_t vk7 = vld1_u8(w);
609 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
610 const uint8x8_t vi7 = vreinterpret_u8_u64(
611 vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
612 const int16x8_t vxk7 =
613 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
614 const int16x8_t vxi7 =
615 vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
616 vaccX1_lo =
617 vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
618 vaccX1_hi =
619 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
620
621 const uint8x8_t vk8 = vld1_u8(w);
622 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
623 const uint8x8_t vi8 = vreinterpret_u8_u64(
624 vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
625 const int16x8_t vxk8 =
626 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
627 const int16x8_t vxi8 =
628 vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
629 vaccX0_lo =
630 vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
631 vaccX0_hi =
632 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
633
634 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
635 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
636
637 const int32x4_t vacc_lo_old = vld1q_s32(outacc);
638 const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4);
639 vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
640 vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
641 vst1q_s32(outacc, vacc_lo);
642 outacc += 4;
643 vst1q_s32(outacc, vacc_hi);
644 outacc += 4;
645 }
646 }
647
648 {
649 const uint8_t* i0 = input[18];
650 const uint8_t* i1 = input[19];
651 const uint8_t* i2 = input[20];
652 const uint8_t* i3 = input[21];
653 const uint8_t* i4 = input[22];
654 const uint8_t* i5 = input[23];
655 const uint8_t* i6 = input[24];
656 const uint8_t* i7 = input[25];
657 const uint8_t* i8 = input[26];
658 input = (const uint8_t**)((uintptr_t)input + input_col_stride);
659 output = output_start;
660 outacc = outacc32;
661
662 size_t c = channels;
663 const uint8_t* kernel_zero_points_ptr =
664 quantization_params->neon.kernel_zero_points + channels;
665 const float* requantization_scales_ptr =
666 quantization_params->neon.requantization_scales + channels;
667 for (; c >= 8; c -= 8) {
668 const uint8x8_t vkernel_zero_point =
669 vld1_u8(kernel_zero_points_ptr - c);
670 const uint8x8_t vk0 = vld1_u8(w);
671 w += 8;
672 const uint8x8_t vi0 = vld1_u8(i0);
673 i0 += 8;
674 const int16x8_t vxk0 =
675 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
676 const int16x8_t vxi0 =
677 vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
678 int32x4_t vaccX0_lo =
679 vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
680 int32x4_t vaccX0_hi =
681 vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
682
683 const uint8x8_t vk1 = vld1_u8(w);
684 w += 8;
685 const uint8x8_t vi1 = vld1_u8(i1);
686 i1 += 8;
687 const int16x8_t vxk1 =
688 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
689 const int16x8_t vxi1 =
690 vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
691 int32x4_t vaccX1_lo =
692 vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
693 int32x4_t vaccX1_hi =
694 vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
695
696 const uint8x8_t vk2 = vld1_u8(w);
697 w += 8;
698 const uint8x8_t vi2 = vld1_u8(i2);
699 i2 += 8;
700 const int16x8_t vxk2 =
701 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
702 const int16x8_t vxi2 =
703 vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
704 vaccX0_lo =
705 vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
706 vaccX0_hi =
707 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
708
709 const uint8x8_t vk3 = vld1_u8(w);
710 w += 8;
711 const uint8x8_t vi3 = vld1_u8(i3);
712 i3 += 8;
713 const int16x8_t vxk3 =
714 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
715 const int16x8_t vxi3 =
716 vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
717 vaccX1_lo =
718 vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
719 vaccX1_hi =
720 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
721
722 const uint8x8_t vk4 = vld1_u8(w);
723 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
724 const uint8x8_t vi4 = vld1_u8(i4);
725 i4 += 8;
726 const int16x8_t vxk4 =
727 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
728 const int16x8_t vxi4 =
729 vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
730 vaccX0_lo =
731 vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
732 vaccX0_hi =
733 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
734
735 const uint8x8_t vk5 = vld1_u8(w);
736 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
737 const uint8x8_t vi5 = vld1_u8(i5);
738 i5 += 8;
739 const int16x8_t vxk5 =
740 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
741 const int16x8_t vxi5 =
742 vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
743 vaccX1_lo =
744 vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
745 vaccX1_hi =
746 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
747
748 const uint8x8_t vk6 = vld1_u8(w);
749 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
750 const uint8x8_t vi6 = vld1_u8(i6);
751 i6 += 8;
752 const int16x8_t vxk6 =
753 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
754 const int16x8_t vxi6 =
755 vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
756 vaccX0_lo =
757 vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
758 vaccX0_hi =
759 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
760
761 const uint8x8_t vk7 = vld1_u8(w);
762 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
763 const uint8x8_t vi7 = vld1_u8(i7);
764 i7 += 8;
765 const int16x8_t vxk7 =
766 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
767 const int16x8_t vxi7 =
768 vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
769 vaccX1_lo =
770 vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
771 vaccX1_hi =
772 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
773
774 const uint8x8_t vk8 = vld1_u8(w);
775 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
776 const uint8x8_t vi8 = vld1_u8(i8);
777 i8 += 8;
778 const int16x8_t vxk8 =
779 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
780 const int16x8_t vxi8 =
781 vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
782 vaccX0_lo =
783 vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
784 vaccX0_hi =
785 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
786
787 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
788 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
789
790 const int32x4_t vacc_lo_old = vld1q_s32(outacc);
791 outacc += 4;
792 const int32x4_t vacc_hi_old = vld1q_s32(outacc);
793 outacc += 4;
794 vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
795 vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
796
797 const float32x4_t requantization_scale_v_lo =
798 vld1q_f32(requantization_scales_ptr - c);
799 const float32x4_t requantization_scale_v_hi =
800 vld1q_f32(requantization_scales_ptr - c + 4);
801
802 const float32x4_t vacc_lo_f =
803 vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
804 const float32x4_t vacc_hi_f =
805 vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
806
807 #ifdef __aarch64__
808 vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
809 vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
810
811 const int16x8_t vacc = vqaddq_s16(
812 vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi),
813 voutput_zero_point);
814
815 uint8x8_t vout = vqmovun_s16(vacc);
816 vout = vmax_u8(vout, voutput_min);
817 vout = vmin_u8(vout, voutput_max);
818 #else
819 const float32x4_t vacc_lo_f_clamped =
820 vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
821 const float32x4_t vacc_hi_f_clamped =
822 vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
823 vacc_lo = vsubq_s32(
824 vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)),
825 vimagic);
826 vacc_hi = vsubq_s32(
827 vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)),
828 vimagic);
829 const int16x8_t vacc =
830 vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
831
832 uint8x8_t vout = vqmovun_s16(vacc);
833 #endif
834
835 vst1_u8(output, vout);
836 output += 8;
837 }
838 if (c != 0) {
839 const size_t c_predecrement = 8 - c;
840 const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement);
841 i0 -= c_predecrement;
842 i1 -= c_predecrement;
843 i2 -= c_predecrement;
844 i3 -= c_predecrement;
845 i4 -= c_predecrement;
846 i5 -= c_predecrement;
847 i6 -= c_predecrement;
848 i7 -= c_predecrement;
849 i8 -= c_predecrement;
850
851 const uint8x8_t vkernel_zero_point = vld1_u8(
852 &quantization_params->neon.kernel_zero_points[channels - c]);
853 const uint8x8_t vk0 = vld1_u8(w);
854 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
855 const uint8x8_t vi0 = vreinterpret_u8_u64(
856 vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift));
857 const int16x8_t vxk0 =
858 vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point));
859 const int16x8_t vxi0 =
860 vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point));
861 int32x4_t vaccX0_lo =
862 vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0));
863 int32x4_t vaccX0_hi =
864 vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0));
865
866 const uint8x8_t vk1 = vld1_u8(w);
867 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
868 const uint8x8_t vi1 = vreinterpret_u8_u64(
869 vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift));
870 const int16x8_t vxk1 =
871 vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point));
872 const int16x8_t vxi1 =
873 vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point));
874 int32x4_t vaccX1_lo =
875 vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1));
876 int32x4_t vaccX1_hi =
877 vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1));
878
879 const uint8x8_t vk2 = vld1_u8(w);
880 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
881 const uint8x8_t vi2 = vreinterpret_u8_u64(
882 vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift));
883 const int16x8_t vxk2 =
884 vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point));
885 const int16x8_t vxi2 =
886 vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point));
887 vaccX0_lo =
888 vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2));
889 vaccX0_hi =
890 vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2));
891
892 const uint8x8_t vk3 = vld1_u8(w);
893 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
894 const uint8x8_t vi3 = vreinterpret_u8_u64(
895 vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift));
896 const int16x8_t vxk3 =
897 vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point));
898 const int16x8_t vxi3 =
899 vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point));
900 vaccX1_lo =
901 vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3));
902 vaccX1_hi =
903 vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3));
904
905 const uint8x8_t vk4 = vld1_u8(w);
906 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
907 const uint8x8_t vi4 = vreinterpret_u8_u64(
908 vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift));
909 const int16x8_t vxk4 =
910 vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point));
911 const int16x8_t vxi4 =
912 vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point));
913 vaccX0_lo =
914 vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4));
915 vaccX0_hi =
916 vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4));
917
918 const uint8x8_t vk5 = vld1_u8(w);
919 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
920 const uint8x8_t vi5 = vreinterpret_u8_u64(
921 vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift));
922 const int16x8_t vxk5 =
923 vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point));
924 const int16x8_t vxi5 =
925 vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point));
926 vaccX1_lo =
927 vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5));
928 vaccX1_hi =
929 vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5));
930
931 const uint8x8_t vk6 = vld1_u8(w);
932 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
933 const uint8x8_t vi6 = vreinterpret_u8_u64(
934 vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift));
935 const int16x8_t vxk6 =
936 vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point));
937 const int16x8_t vxi6 =
938 vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point));
939 vaccX0_lo =
940 vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6));
941 vaccX0_hi =
942 vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6));
943
944 const uint8x8_t vk7 = vld1_u8(w);
945 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
946 const uint8x8_t vi7 = vreinterpret_u8_u64(
947 vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift));
948 const int16x8_t vxk7 =
949 vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point));
950 const int16x8_t vxi7 =
951 vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point));
952 vaccX1_lo =
953 vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7));
954 vaccX1_hi =
955 vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7));
956
957 const uint8x8_t vk8 = vld1_u8(w);
958 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
959 const uint8x8_t vi8 = vreinterpret_u8_u64(
960 vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift));
961 const int16x8_t vxk8 =
962 vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point));
963 const int16x8_t vxi8 =
964 vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point));
965 vaccX0_lo =
966 vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8));
967 vaccX0_hi =
968 vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8));
969
970 int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
971 int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
972
973 const int32x4_t vacc_lo_old = vld1q_s32(outacc);
974 const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4);
975 vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
976 vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
977
978 const float32x4_t requantization_scale_v_lo = vld1q_f32(
979 &quantization_params->neon.requantization_scales[channels - c]);
980 const float32x4_t requantization_scale_v_hi =
981 vld1q_f32(&quantization_params->neon
982 .requantization_scales[channels - c + 4]);
983
984 const float32x4_t vacc_lo_f =
985 vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
986 const float32x4_t vacc_hi_f =
987 vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
988
989 #ifdef __aarch64__
990 vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
991 vacc_hi = vcvtnq_s32_f32(vacc_hi_f);
992
993 const int16x8_t vacc = vqaddq_s16(
994 vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi),
995 voutput_zero_point);
996
997 uint8x8_t vout = vqmovun_s16(vacc);
998 vout = vmax_u8(vout, voutput_min);
999 vout = vmin_u8(vout, voutput_max);
1000 #else
1001 const float32x4_t vacc_lo_f_clamped =
1002 vminq_f32(vmaxq_f32(vacc_lo_f, vfmin), vfmax);
1003 const float32x4_t vacc_hi_f_clamped =
1004 vminq_f32(vmaxq_f32(vacc_hi_f, vfmin), vfmax);
1005 vacc_lo = vsubq_s32(
1006 vreinterpretq_s32_f32(vaddq_f32(vacc_lo_f_clamped, vfmagic)),
1007 vimagic);
1008 vacc_hi = vsubq_s32(
1009 vreinterpretq_s32_f32(vaddq_f32(vacc_hi_f_clamped, vfmagic)),
1010 vimagic);
1011 const int16x8_t vacc =
1012 vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
1013
1014 uint8x8_t vout = vqmovun_s16(vacc);
1015 #endif
1016
1017 if (c & 4) {
1018 vst1_lane_u32(
1019 __builtin_assume_aligned(output, 1),
1020 vreinterpret_u32_u8(vout),
1021 0);
1022 output += 4;
1023 vout = vext_u8(vout, vout, 4);
1024 }
1025 if (c & 2) {
1026 vst1_lane_u16(
1027 __builtin_assume_aligned(output, 1),
1028 vreinterpret_u16_u8(vout),
1029 0);
1030 output += 2;
1031 vout = vext_u8(vout, vout, 2);
1032 }
1033 if (c & 1) {
1034 vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0);
1035 output++;
1036 }
1037 }
1038 }
1039
1040 output = (uint8_t*)((uintptr_t)output + output_increment);
1041 }
1042 input = (const uint8_t**)((uintptr_t)input_row_start + input_row_stride);
1043 }
1044 }
1045