/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8conv.h>
#include <requantization/runtime-neon.h>

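/*
 * Quantized (uint8) convolution micro-kernel producing a 4-row by
 * 8-output-channel tile of the output. mr and nr give the number of live
 * rows and columns in the tile, kc is the reduction length per input
 * pointer, and ks is the number of pointer groups consumed from the
 * indirection buffer a.
 */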
void pytorch_q8conv_ukernel_4x8__neon(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const uint8_t** restrict a,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[restrict static 1]) {
  const uint8x8_t va_zero_point =
      vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
  // Assumes that kernel_zero_points is padded with enough trailing elements
  // to make its length a multiple of 8, so this 8-byte load stays in bounds.
  const uint8x8_t vb_zero_point =
      vld1_u8((const uint8_t*)&quantization_params->neon.kernel_zero_points
          [output_channel_index]);

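  // Initialize all four row accumulators from the per-output-channel bias
  // packed at the start of w; rows 1-3 start from the same values as row 0.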
  int32x4_t vacc0x0123 = vld1q_s32(w);
  w = (void*)((uintptr_t)w + sizeof(int32x4_t));
  int32x4_t vacc0x4567 = vld1q_s32(w);
  w = (void*)((uintptr_t)w + sizeof(int32x4_t));
  int32x4_t vacc1x0123 = vacc0x0123;
  int32x4_t vacc1x4567 = vacc0x4567;
  int32x4_t vacc2x0123 = vacc0x0123;
  int32x4_t vacc2x4567 = vacc0x4567;
  int32x4_t vacc3x0123 = vacc0x0123;
  int32x4_t vacc3x4567 = vacc0x4567;

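  // Each outer iteration pulls one group of 4 row pointers from the
  // indirection buffer a and accumulates kc products per pointer.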
  do {
    const uint8_t* restrict a0 = *a++;
    const uint8_t* restrict a1 = *a++;
    const uint8_t* restrict a2 = *a++;
    const uint8_t* restrict a3 = *a++;

    size_t k = kc;
    for (; k >= 8; k -= 8) {
      const uint8x8_t va0 = vld1_u8(a0);
      a0 += 8;
      const uint8x8_t va1 = vld1_u8(a1);
      a1 += 8;
      const uint8x8_t va2 = vld1_u8(a2);
      a2 += 8;
      const uint8x8_t va3 = vld1_u8(a3);
      a3 += 8;
      const int16x8_t vxa0 =
          vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
      const int16x8_t vxa1 =
          vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
      const int16x8_t vxa2 =
          vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
      const int16x8_t vxa3 =
          vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));

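      // Eight unrolled steps, one per input channel in this group: each
      // loads 8 packed weights and multiply-accumulates them against a
      // single activation lane of each of the four rows.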
      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 3);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 3);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 3);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 3);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 3);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 3);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 3);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 3);
      }
    }
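    // Remainder for k < 8: reload the 8 bytes ending at a + k, then shift
    // the k valid bytes down to the low lanes (vshl on the u64 view with a
    // negative count is a right shift). Lanes at index k and above are
    // never referenced by the multiply-accumulates below.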
    if (k != 0) {
      const size_t a_predecrement = 8 - k;
      const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement);
      const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift));
      const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift));
      const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift));
      const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift));
      const int16x8_t vxa0 =
          vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
      const int16x8_t vxa1 =
          vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
      const int16x8_t vxa2 =
          vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
      const int16x8_t vxa3 =
          vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
      }

      if (k >= 2) {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);

        if (k > 2) {
          const uint8x8_t vb01234567 = vld1_u8(w);
          w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
          const int16x8_t vxb01234567 =
              vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

          vacc0x0123 = vmlal_lane_s16(
              vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
          vacc0x4567 = vmlal_lane_s16(
              vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
          vacc1x0123 = vmlal_lane_s16(
              vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
          vacc1x4567 = vmlal_lane_s16(
              vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
          vacc2x0123 = vmlal_lane_s16(
              vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
          vacc2x4567 = vmlal_lane_s16(
              vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
          vacc3x0123 = vmlal_lane_s16(
              vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
          vacc3x4567 = vmlal_lane_s16(
              vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);

          if (k >= 4) {
            const uint8x8_t vb01234567 = vld1_u8(w);
            w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
            const int16x8_t vxb01234567 =
                vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

            vacc0x0123 = vmlal_lane_s16(
                vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
            vacc0x4567 = vmlal_lane_s16(
                vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
            vacc1x0123 = vmlal_lane_s16(
                vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
            vacc1x4567 = vmlal_lane_s16(
                vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
            vacc2x0123 = vmlal_lane_s16(
                vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
            vacc2x4567 = vmlal_lane_s16(
                vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
            vacc3x0123 = vmlal_lane_s16(
                vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
            vacc3x4567 = vmlal_lane_s16(
                vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);

            if (k > 4) {
              const uint8x8_t vb01234567 = vld1_u8(w);
              w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
              const int16x8_t vxb01234567 =
                  vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

              vacc0x0123 = vmlal_lane_s16(
                  vacc0x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa0),
                  0);
              vacc0x4567 = vmlal_lane_s16(
                  vacc0x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa0),
                  0);
              vacc1x0123 = vmlal_lane_s16(
                  vacc1x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa1),
                  0);
              vacc1x4567 = vmlal_lane_s16(
                  vacc1x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa1),
                  0);
              vacc2x0123 = vmlal_lane_s16(
                  vacc2x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa2),
                  0);
              vacc2x4567 = vmlal_lane_s16(
                  vacc2x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa2),
                  0);
              vacc3x0123 = vmlal_lane_s16(
                  vacc3x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa3),
                  0);
              vacc3x4567 = vmlal_lane_s16(
                  vacc3x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa3),
                  0);

              if (k >= 6) {
                const uint8x8_t vb01234567 = vld1_u8(w);
                w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
                const int16x8_t vxb01234567 =
                    vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

                vacc0x0123 = vmlal_lane_s16(
                    vacc0x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa0),
                    1);
                vacc0x4567 = vmlal_lane_s16(
                    vacc0x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa0),
                    1);
                vacc1x0123 = vmlal_lane_s16(
                    vacc1x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa1),
                    1);
                vacc1x4567 = vmlal_lane_s16(
                    vacc1x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa1),
                    1);
                vacc2x0123 = vmlal_lane_s16(
                    vacc2x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa2),
                    1);
                vacc2x4567 = vmlal_lane_s16(
                    vacc2x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa2),
                    1);
                vacc3x0123 = vmlal_lane_s16(
                    vacc3x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa3),
                    1);
                vacc3x4567 = vmlal_lane_s16(
                    vacc3x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa3),
                    1);

                if (k > 6) {
                  const uint8x8_t vb01234567 = vld1_u8(w);
                  w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
                  const int16x8_t vxb01234567 = vreinterpretq_s16_u16(
                      vsubl_u8(vb01234567, vb_zero_point));

                  vacc0x0123 = vmlal_lane_s16(
                      vacc0x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa0),
                      2);
                  vacc0x4567 = vmlal_lane_s16(
                      vacc0x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa0),
                      2);
                  vacc1x0123 = vmlal_lane_s16(
                      vacc1x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa1),
                      2);
                  vacc1x4567 = vmlal_lane_s16(
                      vacc1x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa1),
                      2);
                  vacc2x0123 = vmlal_lane_s16(
                      vacc2x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa2),
                      2);
                  vacc2x4567 = vmlal_lane_s16(
                      vacc2x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa2),
                      2);
                  vacc3x0123 = vmlal_lane_s16(
                      vacc3x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa3),
                      2);
                  vacc3x4567 = vmlal_lane_s16(
                      vacc3x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa3),
                      2);
                }
              }
            }
          }
        }
      }
    }
  } while (--ks != 0);

  // Use two VLD1 loads instead of one VLD2: on Cortex-A75, VLD2 has a
  // latency of 8 cycles vs. 5 for VLD1, while both sustain a throughput of
  // 2 per cycle, so the pair of VLD1s is likely faster.
  const float32x4_t requantization_scale_c0123 = vld1q_f32(
      &quantization_params->neon.requantization_scales[output_channel_index]);
  const float32x4_t requantization_scale_c4567 = vld1q_f32(
      &quantization_params->neon.requantization_scales
          [output_channel_index + 4]);

  /*
   * Convert the int32_t accumulators to FP32 and multiply by the FP32 scale.
   * Both operations involve statistically unbiased rounding:
   * - Large int32_t values can't be exactly represented as FP32. The ARM NEON
   *   conversion instruction rounds them to the nearest FP32 value with ties
   *   to even.
   * - The product of two FP32 values is generally not exactly representable
   *   as an FP32 value, and is rounded to the nearest FP32 value with ties
   *   to even.
   */
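  /*
   * Scalar sketch of the two rounding steps above (a reference under
   * IEEE-754 binary32 round-to-nearest-even semantics, not part of the
   * kernel):
   *
   *   float acc_f = (float)acc;        // int32 -> fp32, rounds to nearest
   *   acc_f *= requantization_scale;   // product rounds to nearest again
   */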
  const float32x4_t vacc0x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc0x0123), requantization_scale_c0123);
  const float32x4_t vacc1x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc1x0123), requantization_scale_c0123);
  const float32x4_t vacc2x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc2x0123), requantization_scale_c0123);
  const float32x4_t vacc3x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc3x0123), requantization_scale_c0123);
  const float32x4_t vacc0x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc0x4567), requantization_scale_c4567);
  const float32x4_t vacc1x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc1x4567), requantization_scale_c4567);
  const float32x4_t vacc2x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc2x4567), requantization_scale_c4567);
  const float32x4_t vacc3x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc3x4567), requantization_scale_c4567);

#ifdef __aarch64__
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  /*
   * Leverage the "Floating-point Convert to Signed integer, rounding to
   * nearest with ties to even" instruction. This is an ARMv8 instruction
   * (always available in AArch64) that saturates the result on overflow.
   * We don't need to handle saturated results specially; they are clamped
   * at the last stage.
   */
  vacc0x0123 = vcvtnq_s32_f32(vacc0x0123_f);
  vacc1x0123 = vcvtnq_s32_f32(vacc1x0123_f);
  vacc2x0123 = vcvtnq_s32_f32(vacc2x0123_f);
  vacc3x0123 = vcvtnq_s32_f32(vacc3x0123_f);
  vacc0x4567 = vcvtnq_s32_f32(vacc0x4567_f);
  vacc1x4567 = vcvtnq_s32_f32(vacc1x4567_f);
  vacc2x4567 = vcvtnq_s32_f32(vacc2x4567_f);
  vacc3x4567 = vcvtnq_s32_f32(vacc3x4567_f);

  const int16x8_t vacc0x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
  const int16x8_t vacc1x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
  const int16x8_t vacc2x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
  const int16x8_t vacc3x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);

  uint8x16_t vout0x01234567_1x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
  uint8x16_t vout2x01234567_3x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);

  const uint8x16_t voutput_min =
      vld1q_dup_u8(&quantization_params->neon.output_min);
  const uint8x16_t voutput_max =
      vld1q_dup_u8(&quantization_params->neon.output_max);

  vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
  vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
  vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
  vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
  /*
   * ARMv7 NEON offers only a floating-point to integer conversion instruction
   * that rounds towards zero. In lieu of a conversion instruction that rounds
   * to nearest-even, we use a magic trick: add a large number (1.5 * 2**23)
   * to the scaled value to force rounding to an integer, then subtract this
   * magic number as an integer. This trick works only in a limited range
   * (the absolute value of the input must be less than 2**22), so in general
   * we have to clamp the input to this range before using the magic. However,
   * clamping to any smaller range works just as well, so we clamp to the
   * [qmin - zero point, qmax - zero point] range: after the zero point is
   * added to the result, it lands in the target [qmin, qmax] range.
   */
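  /*
   * Scalar sketch of the magic-number conversion (a reference assuming
   * IEEE-754 binary32 with round-to-nearest-even; vimagic is assumed to be
   * the bit pattern of the magic constant minus the output zero point,
   * consistent with the clamp range described above):
   *
   *   const float fmagic = 12582912.0f;              // 1.5 * 2**23
   *   float biased = x_clamped + fmagic;             // forces rounding to an
   *                                                  // integer-valued fp32
   *   int32_t q = bit_cast_to_int32(biased) - imagic; // recover the integer,
   *                                                  // zero point included
   */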
  const float32x4_t vacc0x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc0x0123_f, vfmin), vfmax);
  const float32x4_t vacc1x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc1x0123_f, vfmin), vfmax);
  const float32x4_t vacc2x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc2x0123_f, vfmin), vfmax);
  const float32x4_t vacc3x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc3x0123_f, vfmin), vfmax);
  const float32x4_t vacc0x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc0x4567_f, vfmin), vfmax);
  const float32x4_t vacc1x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc1x4567_f, vfmin), vfmax);
  const float32x4_t vacc2x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc2x4567_f, vfmin), vfmax);
  const float32x4_t vacc3x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc3x4567_f, vfmin), vfmax);

  /*
   * Conversion to integer using the "magic trick". The rounding happens in
   * the addition, and the result is rounded to the nearest integer with
   * ties to even.
   */
  vacc0x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc0x0123_f_clamped, vfmagic)), vimagic);
  vacc1x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc1x0123_f_clamped, vfmagic)), vimagic);
  vacc2x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc2x0123_f_clamped, vfmagic)), vimagic);
  vacc3x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc3x0123_f_clamped, vfmagic)), vimagic);
  vacc0x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc0x4567_f_clamped, vfmagic)), vimagic);
  vacc1x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc1x4567_f_clamped, vfmagic)), vimagic);
  vacc2x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc2x4567_f_clamped, vfmagic)), vimagic);
  vacc3x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc3x4567_f_clamped, vfmagic)), vimagic);

  const int16x8_t vacc0x01234567 =
      vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567));
  const int16x8_t vacc1x01234567 =
      vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567));
  const int16x8_t vacc2x01234567 =
      vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567));
  const int16x8_t vacc3x01234567 =
      vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567));

  uint8x16_t vout0x01234567_1x01234567 =
      vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
  uint8x16_t vout2x01234567_3x01234567 =
      vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
#endif

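  // Set up one output row pointer per tile row. For tiles with fewer than
  // 4 live rows (mr < 4), the extra pointers alias the previous row, so
  // their stores are harmless overwrites of valid memory.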
  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
  if (mr != 4) {
    c3 = c2;
  }
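  // Store the tile: all 8 columns per row when nr == 8; otherwise emit
  // 4-, 2-, and 1-column chunks, rotating the vectors with vext between
  // chunks so the next chunk is always in the low lanes.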
  if (nr == 8) {
    vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567));
    vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567));
    vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567));
    vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567));
  } else {
    if (nr >= 4) {
      vst1q_lane_u32(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          0);
      c0 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          2);
      c1 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          0);
      c2 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          2);
      c3 += 4;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
      nr -= 4;
    }
    if (nr >= 2) {
      vst1q_lane_u16(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          0);
      c0 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          4);
      c1 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          0);
      c2 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          4);
      c3 += 2;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
      nr -= 2;
    }
    if (nr != 0) {
      vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
      vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
      vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
      vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
    }
  }
}
757