xref: /aosp_15_r20/external/XNNPACK/src/qs8-igemm/gen/3x16c2s4-minmax-rndnu-neon-mlal.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Auto-generated file. Do not edit!
2 //   Template: src/qs8-igemm/c2-neon-mull-shuffle.c.in
3 //   Generator: tools/xngen
4 //
5 // Copyright 2021 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9 
10 #include <assert.h>
11 
12 #include <arm_neon.h>
13 
14 #include <xnnpack/gemm.h>
15 #include <xnnpack/math.h>
16 
17 
// QS8 indirect-GEMM (IGEMM) microkernel, 3x16 tile, "c2s4" layout, NEON
// MULL/MLAL variant, with "rndnu" requantization and [min, max] clamping.
//
// NOTE(review): this file is generated from
// src/qs8-igemm/c2-neon-mull-shuffle.c.in by tools/xngen; durable changes
// belong in the template.
//
// Each iteration of the outer loop produces up to 3 rows x 16 columns of int8
// output. `nc` is walked in 16-wide tiles, and the final tile may be narrower.
//
// Parameters:
//   mr        - number of valid output rows, 1..3. Output pointers for rows
//               at or beyond `mr` alias the previous row's pointer, so their
//               stores hit a valid row and are overwritten by it (see the
//               store order at the end).
//   nc        - number of output columns still to produce.
//   kc        - bytes (int8 elements) of input read from each row pointer per
//               indirection step. It is rounded up to a multiple of 8 below,
//               because A and B are loaded in whole 8-byte chunks. The function
//               is marked XNN_OOB_READS accordingly.
//   ks        - size in bytes of the indirection-pointer list consumed per
//               output tile. It must be a multiple of 3 * sizeof(void*)
//               because each step reads one pointer per row.
//   a         - indirection buffer of input row pointers, read 3 at a time
//               (rows 0, 1, 2). It is rewound by `ks` bytes after each full
//               16-column tile, so every column tile reuses the same pointers.
//   w         - packed weights. For each 16-column tile: 16 int32 values that
//               seed the accumulators (presumably per-column bias; confirm
//               against the packing routine). These are followed by int8
//               weights in exactly the order the loads below consume them.
//   c         - output pointer for row 0.
//   cm_stride - byte distance between consecutive output rows.
//   cn_stride - byte step added to each row pointer after a full 16-column tile.
//   a_offset  - byte offset added to every row pointer that is not `zero`.
//   zero      - sentinel row pointer that is used as-is, without `a_offset`
//               (presumably a zero-filled padding row; confirm with caller).
//   params    - rndnu_neon fields: right_pre_shift, multiplier,
//               right_post_shift, output_zero_point, output_min, output_max.
xnn_qs8_igemm_minmax_rndnu_ukernel_3x16c2s4__neon_mlal(size_t mr,size_t nc,size_t kc,size_t ks,const int8_t ** restrict a,const void * restrict w,int8_t * restrict c,size_t cm_stride,size_t cn_stride,size_t a_offset,const int8_t * zero,const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])18 void xnn_qs8_igemm_minmax_rndnu_ukernel_3x16c2s4__neon_mlal(
19     size_t mr,
20     size_t nc,
21     size_t kc,
22     size_t ks,
23     const int8_t** restrict a,
24     const void* restrict w,
25     int8_t* restrict c,
26     size_t cm_stride,
27     size_t cn_stride,
28     size_t a_offset,
29     const int8_t* zero,
30     const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
31 {
32   assert(mr != 0);
33   assert(mr <= 3);
34   assert(nc != 0);
35   assert(kc != 0);
36   assert(ks != 0);
37   assert(ks % (3 * sizeof(void*)) == 0);
38   assert(a_offset % sizeof(int8_t) == 0);
39   assert(a != NULL);
40   assert(w != NULL);
41   assert(c != NULL);
42 
  // Per-row output pointers. Rows beyond `mr` alias the row above them, so the
  // kernel always computes 3 rows and never writes outside the valid rows.
43   int8_t* c0 = c;
44   int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
45   if XNN_UNPREDICTABLE(mr < 2) {
46     c1 = c0;
47   }
48   int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
49   if XNN_UNPREDICTABLE(mr <= 2) {
50     c2 = c1;
51   }
52 
  // A and B are consumed in 8-byte chunks, so K is processed in multiples of 8.
53   kc = round_up_po2(kc, 8 * sizeof(int8_t));
  // Outer loop: one 16-column output tile per iteration.
54   do {
    // Seed all 3 rows' accumulators from the 16 int32 values at the head of
    // this tile's packed weights.
55     int32x4_t vacc0x0123 = vld1q_s32(w); w = (const int32_t*) w + 4;
56     int32x4_t vacc0x4567 = vld1q_s32(w); w = (const int32_t*) w + 4;
57     int32x4_t vacc0x89AB = vld1q_s32(w); w = (const int32_t*) w + 4;
58     int32x4_t vacc0xCDEF = vld1q_s32(w); w = (const int32_t*) w + 4;
59     int32x4_t vacc1x0123 = vacc0x0123;
60     int32x4_t vacc1x4567 = vacc0x4567;
61     int32x4_t vacc1x89AB = vacc0x89AB;
62     int32x4_t vacc1xCDEF = vacc0xCDEF;
63     int32x4_t vacc2x0123 = vacc0x0123;
64     int32x4_t vacc2x4567 = vacc0x4567;
65     int32x4_t vacc2x89AB = vacc0x89AB;
66     int32x4_t vacc2xCDEF = vacc0xCDEF;
67 
    // Walk the indirection list: `ks` bytes of pointers, 3 pointers per step.
68     size_t p = ks;
69     do {
      // Fetch this step's row pointers. Only pointers other than `zero` get
      // `a_offset` added.
70       const int8_t* restrict a0 = a[0];
71       if XNN_UNPREDICTABLE(a0 != zero) {
72         a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
73       }
74       const int8_t* restrict a1 = a[1];
75       if XNN_UNPREDICTABLE(a1 != zero) {
76         a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
77       }
78       const int8_t* restrict a2 = a[2];
79       if XNN_UNPREDICTABLE(a2 != zero) {
80         a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
81       }
82       a += 3;
83 
      // Inner-product scheme ("c2s4"):
      //   - Each int8x8_t of B holds 2 consecutive k values for each of 4
      //     output columns.
      //   - vmull_s8 multiplies it bytewise with 8 bytes of A (4 k-pairs).
      //   - vpadalq_s16 adds adjacent int16 products into the 4 int32 column
      //     accumulators.
      //   - Between the 4 steps c0..c3, A is rotated left by 2 bytes
      //     (vext_s8), so every column meets each k-pair of the 8-byte A
      //     chunk exactly once.
      //
      // Main loop: 16 bytes of K per row per iteration, split into two 8-byte
      // halves (x0, x1).
84       size_t k = kc;
85       while (k >= 16 * sizeof(int8_t)) {
86         int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
87         int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
88         int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
89         int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
90         int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
91         int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
92 
        // First-half (x0) weights for all 4 shuffle steps (c0..c3) x 4 column
        // groups. The x1 weights are loaded interleaved with the arithmetic
        // below, in the same step-major, column-group-minor order.
93         const int8x8_t vb0123c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
94         const int8x8_t vb4567c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
95         const int8x8_t vb89ABc0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
96         const int8x8_t vbCDEFc0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
97         const int8x8_t vb0123c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
98         const int8x8_t vb4567c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
99         const int8x8_t vb89ABc1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
100         const int8x8_t vbCDEFc1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
101         const int8x8_t vb0123c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
102         const int8x8_t vb4567c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
103         const int8x8_t vb89ABc2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
104         const int8x8_t vbCDEFc2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
105         const int8x8_t vb0123c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
106         const int8x8_t vb4567c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
107         const int8x8_t vb89ABc3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
108         const int8x8_t vbCDEFc3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
109 
        // vmlal_s8 sums the x0 and x1 products in int16 (wrapping) before
        // widening. Each int8*int8 product lies in [-16256, 16384], so the
        // int16 sum can overflow only when both products are (-128)*(-128).
        // NOTE(review): presumably the weight range/packing rules this out
        // (e.g. weights never equal -128); confirm.
110         int16x8_t vprod0x0123c0 = vmull_s8(vb0123c0x0, va0x0);
111         int16x8_t vprod1x0123c0 = vmull_s8(vb0123c0x0, va1x0);
112         int16x8_t vprod2x0123c0 = vmull_s8(vb0123c0x0, va2x0);
113         const int8x8_t vb0123c0x1 = vld1_s8(w); w = (const int8_t*) w + 8;
114         vprod0x0123c0 = vmlal_s8(vprod0x0123c0, vb0123c0x1, va0x1);
115         vprod1x0123c0 = vmlal_s8(vprod1x0123c0, vb0123c0x1, va1x1);
116         vprod2x0123c0 = vmlal_s8(vprod2x0123c0, vb0123c0x1, va2x1);
117         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c0);
118         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c0);
119         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c0);
120         int16x8_t vprod0x4567c0 = vmull_s8(vb4567c0x0, va0x0);
121         int16x8_t vprod1x4567c0 = vmull_s8(vb4567c0x0, va1x0);
122         int16x8_t vprod2x4567c0 = vmull_s8(vb4567c0x0, va2x0);
123         const int8x8_t vb4567c0x1 = vld1_s8(w); w = (const int8_t*) w + 8;
124         vprod0x4567c0 = vmlal_s8(vprod0x4567c0, vb4567c0x1, va0x1);
125         vprod1x4567c0 = vmlal_s8(vprod1x4567c0, vb4567c0x1, va1x1);
126         vprod2x4567c0 = vmlal_s8(vprod2x4567c0, vb4567c0x1, va2x1);
127         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c0);
128         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c0);
129         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c0);
130         int16x8_t vprod0x89ABc0 = vmull_s8(vb89ABc0x0, va0x0);
131         int16x8_t vprod1x89ABc0 = vmull_s8(vb89ABc0x0, va1x0);
132         int16x8_t vprod2x89ABc0 = vmull_s8(vb89ABc0x0, va2x0);
133         const int8x8_t vb89ABc0x1 = vld1_s8(w); w = (const int8_t*) w + 8;
134         vprod0x89ABc0 = vmlal_s8(vprod0x89ABc0, vb89ABc0x1, va0x1);
135         vprod1x89ABc0 = vmlal_s8(vprod1x89ABc0, vb89ABc0x1, va1x1);
136         vprod2x89ABc0 = vmlal_s8(vprod2x89ABc0, vb89ABc0x1, va2x1);
137         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc0);
138         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc0);
139         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc0);
140         int16x8_t vprod0xCDEFc0 = vmull_s8(vbCDEFc0x0, va0x0);
141         int16x8_t vprod1xCDEFc0 = vmull_s8(vbCDEFc0x0, va1x0);
142         int16x8_t vprod2xCDEFc0 = vmull_s8(vbCDEFc0x0, va2x0);
143         const int8x8_t vbCDEFc0x1 = vld1_s8(w); w = (const int8_t*) w + 8;
144         vprod0xCDEFc0 = vmlal_s8(vprod0xCDEFc0, vbCDEFc0x1, va0x1);
145         vprod1xCDEFc0 = vmlal_s8(vprod1xCDEFc0, vbCDEFc0x1, va1x1);
146         vprod2xCDEFc0 = vmlal_s8(vprod2xCDEFc0, vbCDEFc0x1, va2x1);
147         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc0);
148         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc0);
149         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc0);
        // Rotate A by one k-pair (2 bytes) for step c1.
150         va0x0 = vext_s8(va0x0, va0x0, 2);
151         va0x1 = vext_s8(va0x1, va0x1, 2);
152         va1x0 = vext_s8(va1x0, va1x0, 2);
153         va1x1 = vext_s8(va1x1, va1x1, 2);
154         va2x0 = vext_s8(va2x0, va2x0, 2);
155         va2x1 = vext_s8(va2x1, va2x1, 2);
156         int16x8_t vprod0x0123c1 = vmull_s8(vb0123c1x0, va0x0);
157         int16x8_t vprod1x0123c1 = vmull_s8(vb0123c1x0, va1x0);
158         int16x8_t vprod2x0123c1 = vmull_s8(vb0123c1x0, va2x0);
159         const int8x8_t vb0123c1x1 = vld1_s8(w); w = (const int8_t*) w + 8;
160         vprod0x0123c1 = vmlal_s8(vprod0x0123c1, vb0123c1x1, va0x1);
161         vprod1x0123c1 = vmlal_s8(vprod1x0123c1, vb0123c1x1, va1x1);
162         vprod2x0123c1 = vmlal_s8(vprod2x0123c1, vb0123c1x1, va2x1);
163         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c1);
164         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c1);
165         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c1);
166         int16x8_t vprod0x4567c1 = vmull_s8(vb4567c1x0, va0x0);
167         int16x8_t vprod1x4567c1 = vmull_s8(vb4567c1x0, va1x0);
168         int16x8_t vprod2x4567c1 = vmull_s8(vb4567c1x0, va2x0);
169         const int8x8_t vb4567c1x1 = vld1_s8(w); w = (const int8_t*) w + 8;
170         vprod0x4567c1 = vmlal_s8(vprod0x4567c1, vb4567c1x1, va0x1);
171         vprod1x4567c1 = vmlal_s8(vprod1x4567c1, vb4567c1x1, va1x1);
172         vprod2x4567c1 = vmlal_s8(vprod2x4567c1, vb4567c1x1, va2x1);
173         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c1);
174         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c1);
175         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c1);
176         int16x8_t vprod0x89ABc1 = vmull_s8(vb89ABc1x0, va0x0);
177         int16x8_t vprod1x89ABc1 = vmull_s8(vb89ABc1x0, va1x0);
178         int16x8_t vprod2x89ABc1 = vmull_s8(vb89ABc1x0, va2x0);
179         const int8x8_t vb89ABc1x1 = vld1_s8(w); w = (const int8_t*) w + 8;
180         vprod0x89ABc1 = vmlal_s8(vprod0x89ABc1, vb89ABc1x1, va0x1);
181         vprod1x89ABc1 = vmlal_s8(vprod1x89ABc1, vb89ABc1x1, va1x1);
182         vprod2x89ABc1 = vmlal_s8(vprod2x89ABc1, vb89ABc1x1, va2x1);
183         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc1);
184         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc1);
185         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc1);
186         int16x8_t vprod0xCDEFc1 = vmull_s8(vbCDEFc1x0, va0x0);
187         int16x8_t vprod1xCDEFc1 = vmull_s8(vbCDEFc1x0, va1x0);
188         int16x8_t vprod2xCDEFc1 = vmull_s8(vbCDEFc1x0, va2x0);
189         const int8x8_t vbCDEFc1x1 = vld1_s8(w); w = (const int8_t*) w + 8;
190         vprod0xCDEFc1 = vmlal_s8(vprod0xCDEFc1, vbCDEFc1x1, va0x1);
191         vprod1xCDEFc1 = vmlal_s8(vprod1xCDEFc1, vbCDEFc1x1, va1x1);
192         vprod2xCDEFc1 = vmlal_s8(vprod2xCDEFc1, vbCDEFc1x1, va2x1);
193         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc1);
194         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc1);
195         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc1);
        // Rotate A by one k-pair (2 bytes) for step c2.
196         va0x0 = vext_s8(va0x0, va0x0, 2);
197         va0x1 = vext_s8(va0x1, va0x1, 2);
198         va1x0 = vext_s8(va1x0, va1x0, 2);
199         va1x1 = vext_s8(va1x1, va1x1, 2);
200         va2x0 = vext_s8(va2x0, va2x0, 2);
201         va2x1 = vext_s8(va2x1, va2x1, 2);
202         int16x8_t vprod0x0123c2 = vmull_s8(vb0123c2x0, va0x0);
203         int16x8_t vprod1x0123c2 = vmull_s8(vb0123c2x0, va1x0);
204         int16x8_t vprod2x0123c2 = vmull_s8(vb0123c2x0, va2x0);
205         const int8x8_t vb0123c2x1 = vld1_s8(w); w = (const int8_t*) w + 8;
206         vprod0x0123c2 = vmlal_s8(vprod0x0123c2, vb0123c2x1, va0x1);
207         vprod1x0123c2 = vmlal_s8(vprod1x0123c2, vb0123c2x1, va1x1);
208         vprod2x0123c2 = vmlal_s8(vprod2x0123c2, vb0123c2x1, va2x1);
209         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
210         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
211         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
212         int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2x0, va0x0);
213         int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2x0, va1x0);
214         int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2x0, va2x0);
215         const int8x8_t vb4567c2x1 = vld1_s8(w); w = (const int8_t*) w + 8;
216         vprod0x4567c2 = vmlal_s8(vprod0x4567c2, vb4567c2x1, va0x1);
217         vprod1x4567c2 = vmlal_s8(vprod1x4567c2, vb4567c2x1, va1x1);
218         vprod2x4567c2 = vmlal_s8(vprod2x4567c2, vb4567c2x1, va2x1);
219         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
220         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
221         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
222         int16x8_t vprod0x89ABc2 = vmull_s8(vb89ABc2x0, va0x0);
223         int16x8_t vprod1x89ABc2 = vmull_s8(vb89ABc2x0, va1x0);
224         int16x8_t vprod2x89ABc2 = vmull_s8(vb89ABc2x0, va2x0);
225         const int8x8_t vb89ABc2x1 = vld1_s8(w); w = (const int8_t*) w + 8;
226         vprod0x89ABc2 = vmlal_s8(vprod0x89ABc2, vb89ABc2x1, va0x1);
227         vprod1x89ABc2 = vmlal_s8(vprod1x89ABc2, vb89ABc2x1, va1x1);
228         vprod2x89ABc2 = vmlal_s8(vprod2x89ABc2, vb89ABc2x1, va2x1);
229         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2);
230         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2);
231         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2);
232         int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2x0, va0x0);
233         int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2x0, va1x0);
234         int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2x0, va2x0);
235         const int8x8_t vbCDEFc2x1 = vld1_s8(w); w = (const int8_t*) w + 8;
236         vprod0xCDEFc2 = vmlal_s8(vprod0xCDEFc2, vbCDEFc2x1, va0x1);
237         vprod1xCDEFc2 = vmlal_s8(vprod1xCDEFc2, vbCDEFc2x1, va1x1);
238         vprod2xCDEFc2 = vmlal_s8(vprod2xCDEFc2, vbCDEFc2x1, va2x1);
239         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2);
240         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2);
241         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2);
        // Rotate A by one k-pair (2 bytes) for step c3.
242         va0x0 = vext_s8(va0x0, va0x0, 2);
243         va0x1 = vext_s8(va0x1, va0x1, 2);
244         va1x0 = vext_s8(va1x0, va1x0, 2);
245         va1x1 = vext_s8(va1x1, va1x1, 2);
246         va2x0 = vext_s8(va2x0, va2x0, 2);
247         va2x1 = vext_s8(va2x1, va2x1, 2);
248         int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3x0, va0x0);
249         int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3x0, va1x0);
250         int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3x0, va2x0);
251         const int8x8_t vb0123c3x1 = vld1_s8(w); w = (const int8_t*) w + 8;
252         vprod0x0123c3 = vmlal_s8(vprod0x0123c3, vb0123c3x1, va0x1);
253         vprod1x0123c3 = vmlal_s8(vprod1x0123c3, vb0123c3x1, va1x1);
254         vprod2x0123c3 = vmlal_s8(vprod2x0123c3, vb0123c3x1, va2x1);
255         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
256         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
257         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
258         int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3x0, va0x0);
259         int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3x0, va1x0);
260         int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3x0, va2x0);
261         const int8x8_t vb4567c3x1 = vld1_s8(w); w = (const int8_t*) w + 8;
262         vprod0x4567c3 = vmlal_s8(vprod0x4567c3, vb4567c3x1, va0x1);
263         vprod1x4567c3 = vmlal_s8(vprod1x4567c3, vb4567c3x1, va1x1);
264         vprod2x4567c3 = vmlal_s8(vprod2x4567c3, vb4567c3x1, va2x1);
265         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
266         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
267         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
268         int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3x0, va0x0);
269         int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3x0, va1x0);
270         int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3x0, va2x0);
271         const int8x8_t vb89ABc3x1 = vld1_s8(w); w = (const int8_t*) w + 8;
272         vprod0x89ABc3 = vmlal_s8(vprod0x89ABc3, vb89ABc3x1, va0x1);
273         vprod1x89ABc3 = vmlal_s8(vprod1x89ABc3, vb89ABc3x1, va1x1);
274         vprod2x89ABc3 = vmlal_s8(vprod2x89ABc3, vb89ABc3x1, va2x1);
275         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
276         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
277         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
278         int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3x0, va0x0);
279         int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3x0, va1x0);
280         int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3x0, va2x0);
281         const int8x8_t vbCDEFc3x1 = vld1_s8(w); w = (const int8_t*) w + 8;
282         vprod0xCDEFc3 = vmlal_s8(vprod0xCDEFc3, vbCDEFc3x1, va0x1);
283         vprod1xCDEFc3 = vmlal_s8(vprod1xCDEFc3, vbCDEFc3x1, va1x1);
284         vprod2xCDEFc3 = vmlal_s8(vprod2xCDEFc3, vbCDEFc3x1, va2x1);
285         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
286         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
287         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
288 
289         k -= 16 * sizeof(int8_t);
290       }
      // Remainder: kc is a multiple of 8 and the loop above consumes 16 bytes
      // at a time, so at most one 8-byte chunk is left. It uses the same c2s4
      // scheme with vmull_s8 only (there is no x1 half).
291       if (k != 0) {
292         int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
293         int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
294         int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
295 
296         const int8x8_t vb0123c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
297         const int8x8_t vb4567c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
298         const int8x8_t vb89ABc0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
299         const int8x8_t vbCDEFc0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
300         const int8x8_t vb0123c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
301         const int8x8_t vb4567c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
302         const int8x8_t vb89ABc1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
303         const int8x8_t vbCDEFc1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
304         const int8x8_t vb0123c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
305         const int8x8_t vb4567c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
306         const int8x8_t vb89ABc2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
307         const int8x8_t vbCDEFc2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
308         const int8x8_t vb0123c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
309         const int8x8_t vb4567c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
310         const int8x8_t vb89ABc3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
311         const int8x8_t vbCDEFc3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
312 
313         int16x8_t vprod0x0123c0 = vmull_s8(vb0123c0x0, va0x0);
314         int16x8_t vprod1x0123c0 = vmull_s8(vb0123c0x0, va1x0);
315         int16x8_t vprod2x0123c0 = vmull_s8(vb0123c0x0, va2x0);
316         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c0);
317         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c0);
318         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c0);
319         int16x8_t vprod0x4567c0 = vmull_s8(vb4567c0x0, va0x0);
320         int16x8_t vprod1x4567c0 = vmull_s8(vb4567c0x0, va1x0);
321         int16x8_t vprod2x4567c0 = vmull_s8(vb4567c0x0, va2x0);
322         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c0);
323         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c0);
324         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c0);
325         int16x8_t vprod0x89ABc0 = vmull_s8(vb89ABc0x0, va0x0);
326         int16x8_t vprod1x89ABc0 = vmull_s8(vb89ABc0x0, va1x0);
327         int16x8_t vprod2x89ABc0 = vmull_s8(vb89ABc0x0, va2x0);
328         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc0);
329         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc0);
330         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc0);
331         int16x8_t vprod0xCDEFc0 = vmull_s8(vbCDEFc0x0, va0x0);
332         int16x8_t vprod1xCDEFc0 = vmull_s8(vbCDEFc0x0, va1x0);
333         int16x8_t vprod2xCDEFc0 = vmull_s8(vbCDEFc0x0, va2x0);
334         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc0);
335         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc0);
336         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc0);
        // Rotate A by one k-pair (2 bytes) for step c1.
337         va0x0 = vext_s8(va0x0, va0x0, 2);
338         va1x0 = vext_s8(va1x0, va1x0, 2);
339         va2x0 = vext_s8(va2x0, va2x0, 2);
340         int16x8_t vprod0x0123c1 = vmull_s8(vb0123c1x0, va0x0);
341         int16x8_t vprod1x0123c1 = vmull_s8(vb0123c1x0, va1x0);
342         int16x8_t vprod2x0123c1 = vmull_s8(vb0123c1x0, va2x0);
343         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c1);
344         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c1);
345         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c1);
346         int16x8_t vprod0x4567c1 = vmull_s8(vb4567c1x0, va0x0);
347         int16x8_t vprod1x4567c1 = vmull_s8(vb4567c1x0, va1x0);
348         int16x8_t vprod2x4567c1 = vmull_s8(vb4567c1x0, va2x0);
349         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c1);
350         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c1);
351         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c1);
352         int16x8_t vprod0x89ABc1 = vmull_s8(vb89ABc1x0, va0x0);
353         int16x8_t vprod1x89ABc1 = vmull_s8(vb89ABc1x0, va1x0);
354         int16x8_t vprod2x89ABc1 = vmull_s8(vb89ABc1x0, va2x0);
355         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc1);
356         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc1);
357         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc1);
358         int16x8_t vprod0xCDEFc1 = vmull_s8(vbCDEFc1x0, va0x0);
359         int16x8_t vprod1xCDEFc1 = vmull_s8(vbCDEFc1x0, va1x0);
360         int16x8_t vprod2xCDEFc1 = vmull_s8(vbCDEFc1x0, va2x0);
361         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc1);
362         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc1);
363         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc1);
        // Rotate A by one k-pair (2 bytes) for step c2.
364         va0x0 = vext_s8(va0x0, va0x0, 2);
365         va1x0 = vext_s8(va1x0, va1x0, 2);
366         va2x0 = vext_s8(va2x0, va2x0, 2);
367         int16x8_t vprod0x0123c2 = vmull_s8(vb0123c2x0, va0x0);
368         int16x8_t vprod1x0123c2 = vmull_s8(vb0123c2x0, va1x0);
369         int16x8_t vprod2x0123c2 = vmull_s8(vb0123c2x0, va2x0);
370         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
371         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
372         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
373         int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2x0, va0x0);
374         int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2x0, va1x0);
375         int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2x0, va2x0);
376         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
377         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
378         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
379         int16x8_t vprod0x89ABc2 = vmull_s8(vb89ABc2x0, va0x0);
380         int16x8_t vprod1x89ABc2 = vmull_s8(vb89ABc2x0, va1x0);
381         int16x8_t vprod2x89ABc2 = vmull_s8(vb89ABc2x0, va2x0);
382         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc2);
383         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc2);
384         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc2);
385         int16x8_t vprod0xCDEFc2 = vmull_s8(vbCDEFc2x0, va0x0);
386         int16x8_t vprod1xCDEFc2 = vmull_s8(vbCDEFc2x0, va1x0);
387         int16x8_t vprod2xCDEFc2 = vmull_s8(vbCDEFc2x0, va2x0);
388         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc2);
389         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc2);
390         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc2);
        // Rotate A by one k-pair (2 bytes) for step c3.
391         va0x0 = vext_s8(va0x0, va0x0, 2);
392         va1x0 = vext_s8(va1x0, va1x0, 2);
393         va2x0 = vext_s8(va2x0, va2x0, 2);
394         int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3x0, va0x0);
395         int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3x0, va1x0);
396         int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3x0, va2x0);
397         vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
398         vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
399         vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
400         int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3x0, va0x0);
401         int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3x0, va1x0);
402         int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3x0, va2x0);
403         vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
404         vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
405         vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
406         int16x8_t vprod0x89ABc3 = vmull_s8(vb89ABc3x0, va0x0);
407         int16x8_t vprod1x89ABc3 = vmull_s8(vb89ABc3x0, va1x0);
408         int16x8_t vprod2x89ABc3 = vmull_s8(vb89ABc3x0, va2x0);
409         vacc0x89AB = vpadalq_s16(vacc0x89AB, vprod0x89ABc3);
410         vacc1x89AB = vpadalq_s16(vacc1x89AB, vprod1x89ABc3);
411         vacc2x89AB = vpadalq_s16(vacc2x89AB, vprod2x89ABc3);
412         int16x8_t vprod0xCDEFc3 = vmull_s8(vbCDEFc3x0, va0x0);
413         int16x8_t vprod1xCDEFc3 = vmull_s8(vbCDEFc3x0, va1x0);
414         int16x8_t vprod2xCDEFc3 = vmull_s8(vbCDEFc3x0, va2x0);
415         vacc0xCDEF = vpadalq_s16(vacc0xCDEF, vprod0xCDEFc3);
416         vacc1xCDEF = vpadalq_s16(vacc1xCDEF, vprod1xCDEFc3);
417         vacc2xCDEF = vpadalq_s16(vacc2xCDEF, vprod2xCDEFc3);
418 
419       }
420 
421       p -= 3 * sizeof(void*);
422     } while (p != 0);
423 
    // Requantize the int32 accumulators ("rndnu"):
    //   1. Saturating shift by right_pre_shift (vqshlq_s32; negative counts
    //      shift right).
    //   2. Saturating doubling multiply-high by `multiplier` (vqdmulhq_s32).
    //   3. Rounding shift by right_post_shift (vrshlq_s32; negative counts are
    //      rounding right shifts).
    // NOTE(review): the "right" shift fields are presumably stored as negative
    // shift counts; confirm against the params initializer.
424     const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->rndnu_neon.right_pre_shift);
425     const int32x4_t vmultiplier = vld1q_dup_s32(&params->rndnu_neon.multiplier);
426     const int32x4_t vright_post_shift = vld1q_dup_s32(&params->rndnu_neon.right_post_shift);
427 
428     vacc0x0123 = vqshlq_s32(vacc0x0123, vright_pre_shift);
429     vacc0x4567 = vqshlq_s32(vacc0x4567, vright_pre_shift);
430     vacc0x89AB = vqshlq_s32(vacc0x89AB, vright_pre_shift);
431     vacc0xCDEF = vqshlq_s32(vacc0xCDEF, vright_pre_shift);
432     vacc1x0123 = vqshlq_s32(vacc1x0123, vright_pre_shift);
433     vacc1x4567 = vqshlq_s32(vacc1x4567, vright_pre_shift);
434     vacc1x89AB = vqshlq_s32(vacc1x89AB, vright_pre_shift);
435     vacc1xCDEF = vqshlq_s32(vacc1xCDEF, vright_pre_shift);
436     vacc2x0123 = vqshlq_s32(vacc2x0123, vright_pre_shift);
437     vacc2x4567 = vqshlq_s32(vacc2x4567, vright_pre_shift);
438     vacc2x89AB = vqshlq_s32(vacc2x89AB, vright_pre_shift);
439     vacc2xCDEF = vqshlq_s32(vacc2xCDEF, vright_pre_shift);
440 
441     vacc0x0123 = vqdmulhq_s32(vacc0x0123, vmultiplier);
442     vacc0x4567 = vqdmulhq_s32(vacc0x4567, vmultiplier);
443     vacc0x89AB = vqdmulhq_s32(vacc0x89AB, vmultiplier);
444     vacc0xCDEF = vqdmulhq_s32(vacc0xCDEF, vmultiplier);
445     vacc1x0123 = vqdmulhq_s32(vacc1x0123, vmultiplier);
446     vacc1x4567 = vqdmulhq_s32(vacc1x4567, vmultiplier);
447     vacc1x89AB = vqdmulhq_s32(vacc1x89AB, vmultiplier);
448     vacc1xCDEF = vqdmulhq_s32(vacc1xCDEF, vmultiplier);
449     vacc2x0123 = vqdmulhq_s32(vacc2x0123, vmultiplier);
450     vacc2x4567 = vqdmulhq_s32(vacc2x4567, vmultiplier);
451     vacc2x89AB = vqdmulhq_s32(vacc2x89AB, vmultiplier);
452     vacc2xCDEF = vqdmulhq_s32(vacc2xCDEF, vmultiplier);
453 
454     vacc0x0123 = vrshlq_s32(vacc0x0123, vright_post_shift);
455     vacc0x4567 = vrshlq_s32(vacc0x4567, vright_post_shift);
456     vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_post_shift);
457     vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_post_shift);
458     vacc1x0123 = vrshlq_s32(vacc1x0123, vright_post_shift);
459     vacc1x4567 = vrshlq_s32(vacc1x4567, vright_post_shift);
460     vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_post_shift);
461     vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_post_shift);
462     vacc2x0123 = vrshlq_s32(vacc2x0123, vright_post_shift);
463     vacc2x4567 = vrshlq_s32(vacc2x4567, vright_post_shift);
464     vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_post_shift);
465     vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_post_shift);
466 
    // Narrow to int8:
    //   1. Saturating narrow int32 -> int16.
    //   2. Saturating add of the output zero point.
    //   3. Saturating narrow int16 -> int8.
    // AArch64 narrows straight into the upper half with the *_high forms;
    // AArch32 narrows both halves and combines them.
467     const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->rndnu_neon.output_zero_point);
468 #if XNN_ARCH_ARM64
469     int16x8_t vacc0x01234567 = vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567);
470     int16x8_t vacc0x89ABCDEF = vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF);
471     int16x8_t vacc1x01234567 = vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567);
472     int16x8_t vacc1x89ABCDEF = vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF);
473     int16x8_t vacc2x01234567 = vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567);
474     int16x8_t vacc2x89ABCDEF = vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF);
475 
476     vacc0x01234567 = vqaddq_s16(vacc0x01234567, voutput_zero_point);
477     vacc0x89ABCDEF = vqaddq_s16(vacc0x89ABCDEF, voutput_zero_point);
478     vacc1x01234567 = vqaddq_s16(vacc1x01234567, voutput_zero_point);
479     vacc1x89ABCDEF = vqaddq_s16(vacc1x89ABCDEF, voutput_zero_point);
480     vacc2x01234567 = vqaddq_s16(vacc2x01234567, voutput_zero_point);
481     vacc2x89ABCDEF = vqaddq_s16(vacc2x89ABCDEF, voutput_zero_point);
482 
483     int8x16_t vout0x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc0x89ABCDEF);
484     int8x16_t vout1x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc1x01234567), vacc1x89ABCDEF);
485     int8x16_t vout2x0123456789ABCDEF = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc2x89ABCDEF);
486 #else
487     int16x8_t vacc0x01234567 = vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567));
488     int16x8_t vacc0x89ABCDEF = vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF));
489     int16x8_t vacc1x01234567 = vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567));
490     int16x8_t vacc1x89ABCDEF = vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF));
491     int16x8_t vacc2x01234567 = vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567));
492     int16x8_t vacc2x89ABCDEF = vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF));
493 
494     vacc0x01234567 = vqaddq_s16(vacc0x01234567, voutput_zero_point);
495     vacc0x89ABCDEF = vqaddq_s16(vacc0x89ABCDEF, voutput_zero_point);
496     vacc1x01234567 = vqaddq_s16(vacc1x01234567, voutput_zero_point);
497     vacc1x89ABCDEF = vqaddq_s16(vacc1x89ABCDEF, voutput_zero_point);
498     vacc2x01234567 = vqaddq_s16(vacc2x01234567, voutput_zero_point);
499     vacc2x89ABCDEF = vqaddq_s16(vacc2x89ABCDEF, voutput_zero_point);
500 
501     int8x16_t vout0x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc0x89ABCDEF));
502     int8x16_t vout1x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc1x01234567), vqmovn_s16(vacc1x89ABCDEF));
503     int8x16_t vout2x0123456789ABCDEF = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc2x89ABCDEF));
504 #endif
505 
    // Clamp every output to [output_min, output_max].
506     const int8x16_t voutput_min = vld1q_dup_s8(&params->rndnu_neon.output_min);
507     vout0x0123456789ABCDEF = vmaxq_s8(vout0x0123456789ABCDEF, voutput_min);
508     vout1x0123456789ABCDEF = vmaxq_s8(vout1x0123456789ABCDEF, voutput_min);
509     vout2x0123456789ABCDEF = vmaxq_s8(vout2x0123456789ABCDEF, voutput_min);
510 
511     const int8x16_t voutput_max = vld1q_dup_s8(&params->rndnu_neon.output_max);
512     vout0x0123456789ABCDEF = vminq_s8(vout0x0123456789ABCDEF, voutput_max);
513     vout1x0123456789ABCDEF = vminq_s8(vout1x0123456789ABCDEF, voutput_max);
514     vout2x0123456789ABCDEF = vminq_s8(vout2x0123456789ABCDEF, voutput_max);
515 
    // Full 16-column tile. Rows are stored from 2 down to 0. When mr < 3, the
    // unused row pointers alias a lower row, so the valid rows are written
    // last and their results are what remains in memory. Then advance to the
    // next column tile and rewind `a` by `ks` bytes to reuse the same
    // indirection pointers.
516     if (nc >= 16) {
517       vst1q_s8(c2 + 0, vout2x0123456789ABCDEF);
518       vst1q_s8(c1 + 0, vout1x0123456789ABCDEF);
519       vst1q_s8(c0 + 0, vout0x0123456789ABCDEF);
520 
521       c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
522       c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
523       c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
524 
525       a = (const int8_t**restrict) ((uintptr_t) a - ks);
526 
527       nc -= 16;
528     } else {
      // Partial tile (nc < 16): store 8, 4, 2, then 1 column(s) according to
      // the bits of nc, using the same row 2 -> row 0 order.
      // Register layout: rows 0 and 1 are packed into one 128-bit register.
      //   - Row 0 occupies bytes 0..7 and row 1 bytes 8..15, so row 1 is at
      //     u32 lane 2, u16 lane 4 and s8 lane 8.
      //   - After each partial store, vext rotates the remaining columns down
      //     to the start of each row's half.
529       int8x8_t vout2x01234567 = vget_low_s8(vout2x0123456789ABCDEF);
530       int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vget_low_s8(vout0x0123456789ABCDEF), vget_low_s8(vout1x0123456789ABCDEF));
531       if (nc & 8) {
532         vst1_s8(c2, vout2x01234567); c2 += 8;
533         vst1_s8(c1, vget_high_s8(vout0x01234567_1x01234567)); c1 += 8;
534         vst1_s8(c0, vget_low_s8(vout0x01234567_1x01234567)); c0 += 8;
        // Switch to columns 8..15 for the remaining stores.
535         vout2x01234567 = vget_high_s8(vout2x0123456789ABCDEF);
536         vout0x01234567_1x01234567 = vcombine_s8(vget_high_s8(vout0x0123456789ABCDEF), vget_high_s8(vout1x0123456789ABCDEF));
537       }
538       if (nc & 4) {
539         vst1_lane_u32((void*) c2, vreinterpret_u32_s8(vout2x01234567), 0); c2 += 4;
540         vst1q_lane_u32((void*) c1, vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
541         vst1q_lane_u32((void*) c0, vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
542         vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 4);
543         vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
544       }
545       if (nc & 2) {
546         vst1_lane_u16((void*) c2, vreinterpret_u16_s8(vout2x01234567), 0); c2 += 2;
547         vst1q_lane_u16((void*) c1, vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
548         vst1q_lane_u16((void*) c0, vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
549         vout2x01234567 = vext_s8(vout2x01234567, vout2x01234567, 2);
550         vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
551       }
552       if (nc & 1) {
553         vst1_lane_s8(c2, vout2x01234567, 0);
554         vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
555         vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
556       }
557 
      // A partial tile is always the last one.
558       nc = 0;
559     }
560   } while (nc != 0);
561 }
562