// Auto-generated file. Do not edit!
//   Template: src/qs8-gemm/c2-neon-mull-shuffle.c.in
//   Generator: tools/xngen
//
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>


void xnn_qs8_gemm_minmax_rndnu_ukernel_4x8c2s4__neon_mlal(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

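  // Per-row pointers into A and C. When mr < 4, the pointers for the unused rows alias the
  // previous row, so those rows are computed redundantly and store to the same location.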
  const int8_t* a0 = a;
  int8_t* c0 = c;
  const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
  int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
  int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
  int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    a3 = a2;
    c3 = c2;
  }

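  // Round kc up so the loops below always consume whole 8-byte blocks of A;
  // XNN_OOB_READS documents that the kernel may read past the nominal end of its inputs.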
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  do {
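    // The packed weights begin with the per-channel bias; use it to seed every row's accumulators.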
    int32x4_t vacc0x0123 = vld1q_s32(w); w = (const int32_t*) w + 4;
    int32x4_t vacc0x4567 = vld1q_s32(w); w = (const int32_t*) w + 4;
    int32x4_t vacc1x0123 = vacc0x0123;
    int32x4_t vacc1x4567 = vacc0x4567;
    int32x4_t vacc2x0123 = vacc0x0123;
    int32x4_t vacc2x4567 = vacc0x4567;
    int32x4_t vacc3x0123 = vacc0x0123;
    int32x4_t vacc3x4567 = vacc0x4567;

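    // Main loop: 16 bytes of K per row per iteration, split into two 8-byte halves.
    // Each half is multiplied with its packed weight block (vmull_s8, then vmlal_s8 for the
    // second half) and the 16-bit products are pairwise-accumulated into the 32-bit
    // accumulators with vpadalq_s16. The vext_s8(..., 2) rotations implement the c2s4
    // shuffle that steps A to the next pair of K values for each weight block.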
    size_t k = kc;
    while (k >= 16 * sizeof(int8_t)) {
      int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
      int8x8_t va0x1 = vld1_s8(a0); a0 += 8;
      int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
      int8x8_t va1x1 = vld1_s8(a1); a1 += 8;
      int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
      int8x8_t va2x1 = vld1_s8(a2); a2 += 8;
      int8x8_t va3x0 = vld1_s8(a3); a3 += 8;
      int8x8_t va3x1 = vld1_s8(a3); a3 += 8;

      const int8x8_t vb0123c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb0123c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb0123c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb0123c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;

      int16x8_t vprod0x0123c0 = vmull_s8(vb0123c0x0, va0x0);
      int16x8_t vprod1x0123c0 = vmull_s8(vb0123c0x0, va1x0);
      int16x8_t vprod2x0123c0 = vmull_s8(vb0123c0x0, va2x0);
      int16x8_t vprod3x0123c0 = vmull_s8(vb0123c0x0, va3x0);
      const int8x8_t vb0123c0x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x0123c0 = vmlal_s8(vprod0x0123c0, vb0123c0x1, va0x1);
      vprod1x0123c0 = vmlal_s8(vprod1x0123c0, vb0123c0x1, va1x1);
      vprod2x0123c0 = vmlal_s8(vprod2x0123c0, vb0123c0x1, va2x1);
      vprod3x0123c0 = vmlal_s8(vprod3x0123c0, vb0123c0x1, va3x1);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c0);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c0);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c0);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c0);
      int16x8_t vprod0x4567c0 = vmull_s8(vb4567c0x0, va0x0);
      int16x8_t vprod1x4567c0 = vmull_s8(vb4567c0x0, va1x0);
      int16x8_t vprod2x4567c0 = vmull_s8(vb4567c0x0, va2x0);
      int16x8_t vprod3x4567c0 = vmull_s8(vb4567c0x0, va3x0);
      const int8x8_t vb4567c0x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x4567c0 = vmlal_s8(vprod0x4567c0, vb4567c0x1, va0x1);
      vprod1x4567c0 = vmlal_s8(vprod1x4567c0, vb4567c0x1, va1x1);
      vprod2x4567c0 = vmlal_s8(vprod2x4567c0, vb4567c0x1, va2x1);
      vprod3x4567c0 = vmlal_s8(vprod3x4567c0, vb4567c0x1, va3x1);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c0);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c0);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c0);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c0);
      va0x0 = vext_s8(va0x0, va0x0, 2);
      va0x1 = vext_s8(va0x1, va0x1, 2);
      va1x0 = vext_s8(va1x0, va1x0, 2);
      va1x1 = vext_s8(va1x1, va1x1, 2);
      va2x0 = vext_s8(va2x0, va2x0, 2);
      va2x1 = vext_s8(va2x1, va2x1, 2);
      va3x0 = vext_s8(va3x0, va3x0, 2);
      va3x1 = vext_s8(va3x1, va3x1, 2);
      int16x8_t vprod0x0123c1 = vmull_s8(vb0123c1x0, va0x0);
      int16x8_t vprod1x0123c1 = vmull_s8(vb0123c1x0, va1x0);
      int16x8_t vprod2x0123c1 = vmull_s8(vb0123c1x0, va2x0);
      int16x8_t vprod3x0123c1 = vmull_s8(vb0123c1x0, va3x0);
      const int8x8_t vb0123c1x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x0123c1 = vmlal_s8(vprod0x0123c1, vb0123c1x1, va0x1);
      vprod1x0123c1 = vmlal_s8(vprod1x0123c1, vb0123c1x1, va1x1);
      vprod2x0123c1 = vmlal_s8(vprod2x0123c1, vb0123c1x1, va2x1);
      vprod3x0123c1 = vmlal_s8(vprod3x0123c1, vb0123c1x1, va3x1);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c1);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c1);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c1);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c1);
      int16x8_t vprod0x4567c1 = vmull_s8(vb4567c1x0, va0x0);
      int16x8_t vprod1x4567c1 = vmull_s8(vb4567c1x0, va1x0);
      int16x8_t vprod2x4567c1 = vmull_s8(vb4567c1x0, va2x0);
      int16x8_t vprod3x4567c1 = vmull_s8(vb4567c1x0, va3x0);
      const int8x8_t vb4567c1x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x4567c1 = vmlal_s8(vprod0x4567c1, vb4567c1x1, va0x1);
      vprod1x4567c1 = vmlal_s8(vprod1x4567c1, vb4567c1x1, va1x1);
      vprod2x4567c1 = vmlal_s8(vprod2x4567c1, vb4567c1x1, va2x1);
      vprod3x4567c1 = vmlal_s8(vprod3x4567c1, vb4567c1x1, va3x1);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c1);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c1);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c1);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c1);
      va0x0 = vext_s8(va0x0, va0x0, 2);
      va0x1 = vext_s8(va0x1, va0x1, 2);
      va1x0 = vext_s8(va1x0, va1x0, 2);
      va1x1 = vext_s8(va1x1, va1x1, 2);
      va2x0 = vext_s8(va2x0, va2x0, 2);
      va2x1 = vext_s8(va2x1, va2x1, 2);
      va3x0 = vext_s8(va3x0, va3x0, 2);
      va3x1 = vext_s8(va3x1, va3x1, 2);
      int16x8_t vprod0x0123c2 = vmull_s8(vb0123c2x0, va0x0);
      int16x8_t vprod1x0123c2 = vmull_s8(vb0123c2x0, va1x0);
      int16x8_t vprod2x0123c2 = vmull_s8(vb0123c2x0, va2x0);
      int16x8_t vprod3x0123c2 = vmull_s8(vb0123c2x0, va3x0);
      const int8x8_t vb0123c2x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x0123c2 = vmlal_s8(vprod0x0123c2, vb0123c2x1, va0x1);
      vprod1x0123c2 = vmlal_s8(vprod1x0123c2, vb0123c2x1, va1x1);
      vprod2x0123c2 = vmlal_s8(vprod2x0123c2, vb0123c2x1, va2x1);
      vprod3x0123c2 = vmlal_s8(vprod3x0123c2, vb0123c2x1, va3x1);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2);
      int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2x0, va0x0);
      int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2x0, va1x0);
      int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2x0, va2x0);
      int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2x0, va3x0);
      const int8x8_t vb4567c2x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x4567c2 = vmlal_s8(vprod0x4567c2, vb4567c2x1, va0x1);
      vprod1x4567c2 = vmlal_s8(vprod1x4567c2, vb4567c2x1, va1x1);
      vprod2x4567c2 = vmlal_s8(vprod2x4567c2, vb4567c2x1, va2x1);
      vprod3x4567c2 = vmlal_s8(vprod3x4567c2, vb4567c2x1, va3x1);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2);
      va0x0 = vext_s8(va0x0, va0x0, 2);
      va0x1 = vext_s8(va0x1, va0x1, 2);
      va1x0 = vext_s8(va1x0, va1x0, 2);
      va1x1 = vext_s8(va1x1, va1x1, 2);
      va2x0 = vext_s8(va2x0, va2x0, 2);
      va2x1 = vext_s8(va2x1, va2x1, 2);
      va3x0 = vext_s8(va3x0, va3x0, 2);
      va3x1 = vext_s8(va3x1, va3x1, 2);
      int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3x0, va0x0);
      int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3x0, va1x0);
      int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3x0, va2x0);
      int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3x0, va3x0);
      const int8x8_t vb0123c3x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x0123c3 = vmlal_s8(vprod0x0123c3, vb0123c3x1, va0x1);
      vprod1x0123c3 = vmlal_s8(vprod1x0123c3, vb0123c3x1, va1x1);
      vprod2x0123c3 = vmlal_s8(vprod2x0123c3, vb0123c3x1, va2x1);
      vprod3x0123c3 = vmlal_s8(vprod3x0123c3, vb0123c3x1, va3x1);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
      int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3x0, va0x0);
      int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3x0, va1x0);
      int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3x0, va2x0);
      int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3x0, va3x0);
      const int8x8_t vb4567c3x1 = vld1_s8(w); w = (const int8_t*) w + 8;
      vprod0x4567c3 = vmlal_s8(vprod0x4567c3, vb4567c3x1, va0x1);
      vprod1x4567c3 = vmlal_s8(vprod1x4567c3, vb4567c3x1, va1x1);
      vprod2x4567c3 = vmlal_s8(vprod2x4567c3, vb4567c3x1, va2x1);
      vprod3x4567c3 = vmlal_s8(vprod3x4567c3, vb4567c3x1, va3x1);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);

      k -= 16 * sizeof(int8_t);
    }
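    // Remainder: at most 8 leftover bytes of K per row, handled with a single vmull_s8 pass
    // per weight block (no second half to fuse with vmlal_s8).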
    if (k != 0) {
      int8x8_t va0x0 = vld1_s8(a0); a0 += 8;
      int8x8_t va1x0 = vld1_s8(a1); a1 += 8;
      int8x8_t va2x0 = vld1_s8(a2); a2 += 8;
      int8x8_t va3x0 = vld1_s8(a3); a3 += 8;

      const int8x8_t vb0123c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c0x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb0123c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c1x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb0123c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c2x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb0123c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;
      const int8x8_t vb4567c3x0 = vld1_s8(w); w = (const int8_t*) w + 8;

      int16x8_t vprod0x0123c0 = vmull_s8(vb0123c0x0, va0x0);
      int16x8_t vprod1x0123c0 = vmull_s8(vb0123c0x0, va1x0);
      int16x8_t vprod2x0123c0 = vmull_s8(vb0123c0x0, va2x0);
      int16x8_t vprod3x0123c0 = vmull_s8(vb0123c0x0, va3x0);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c0);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c0);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c0);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c0);
      int16x8_t vprod0x4567c0 = vmull_s8(vb4567c0x0, va0x0);
      int16x8_t vprod1x4567c0 = vmull_s8(vb4567c0x0, va1x0);
      int16x8_t vprod2x4567c0 = vmull_s8(vb4567c0x0, va2x0);
      int16x8_t vprod3x4567c0 = vmull_s8(vb4567c0x0, va3x0);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c0);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c0);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c0);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c0);
      va0x0 = vext_s8(va0x0, va0x0, 2);
      va1x0 = vext_s8(va1x0, va1x0, 2);
      va2x0 = vext_s8(va2x0, va2x0, 2);
      va3x0 = vext_s8(va3x0, va3x0, 2);
      int16x8_t vprod0x0123c1 = vmull_s8(vb0123c1x0, va0x0);
      int16x8_t vprod1x0123c1 = vmull_s8(vb0123c1x0, va1x0);
      int16x8_t vprod2x0123c1 = vmull_s8(vb0123c1x0, va2x0);
      int16x8_t vprod3x0123c1 = vmull_s8(vb0123c1x0, va3x0);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c1);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c1);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c1);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c1);
      int16x8_t vprod0x4567c1 = vmull_s8(vb4567c1x0, va0x0);
      int16x8_t vprod1x4567c1 = vmull_s8(vb4567c1x0, va1x0);
      int16x8_t vprod2x4567c1 = vmull_s8(vb4567c1x0, va2x0);
      int16x8_t vprod3x4567c1 = vmull_s8(vb4567c1x0, va3x0);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c1);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c1);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c1);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c1);
      va0x0 = vext_s8(va0x0, va0x0, 2);
      va1x0 = vext_s8(va1x0, va1x0, 2);
      va2x0 = vext_s8(va2x0, va2x0, 2);
      va3x0 = vext_s8(va3x0, va3x0, 2);
      int16x8_t vprod0x0123c2 = vmull_s8(vb0123c2x0, va0x0);
      int16x8_t vprod1x0123c2 = vmull_s8(vb0123c2x0, va1x0);
      int16x8_t vprod2x0123c2 = vmull_s8(vb0123c2x0, va2x0);
      int16x8_t vprod3x0123c2 = vmull_s8(vb0123c2x0, va3x0);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c2);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c2);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c2);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c2);
      int16x8_t vprod0x4567c2 = vmull_s8(vb4567c2x0, va0x0);
      int16x8_t vprod1x4567c2 = vmull_s8(vb4567c2x0, va1x0);
      int16x8_t vprod2x4567c2 = vmull_s8(vb4567c2x0, va2x0);
      int16x8_t vprod3x4567c2 = vmull_s8(vb4567c2x0, va3x0);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c2);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c2);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c2);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c2);
      va0x0 = vext_s8(va0x0, va0x0, 2);
      va1x0 = vext_s8(va1x0, va1x0, 2);
      va2x0 = vext_s8(va2x0, va2x0, 2);
      va3x0 = vext_s8(va3x0, va3x0, 2);
      int16x8_t vprod0x0123c3 = vmull_s8(vb0123c3x0, va0x0);
      int16x8_t vprod1x0123c3 = vmull_s8(vb0123c3x0, va1x0);
      int16x8_t vprod2x0123c3 = vmull_s8(vb0123c3x0, va2x0);
      int16x8_t vprod3x0123c3 = vmull_s8(vb0123c3x0, va3x0);
      vacc0x0123 = vpadalq_s16(vacc0x0123, vprod0x0123c3);
      vacc1x0123 = vpadalq_s16(vacc1x0123, vprod1x0123c3);
      vacc2x0123 = vpadalq_s16(vacc2x0123, vprod2x0123c3);
      vacc3x0123 = vpadalq_s16(vacc3x0123, vprod3x0123c3);
      int16x8_t vprod0x4567c3 = vmull_s8(vb4567c3x0, va0x0);
      int16x8_t vprod1x4567c3 = vmull_s8(vb4567c3x0, va1x0);
      int16x8_t vprod2x4567c3 = vmull_s8(vb4567c3x0, va2x0);
      int16x8_t vprod3x4567c3 = vmull_s8(vb4567c3x0, va3x0);
      vacc0x4567 = vpadalq_s16(vacc0x4567, vprod0x4567c3);
      vacc1x4567 = vpadalq_s16(vacc1x4567, vprod1x4567c3);
      vacc2x4567 = vpadalq_s16(vacc2x4567, vprod2x4567c3);
      vacc3x4567 = vpadalq_s16(vacc3x4567, vprod3x4567c3);

    }

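    // Requantization (rndnu): saturating pre-shift, saturating doubling multiply-high by the
    // fixed-point multiplier, rounding post-shift, add the output zero point, then narrow with
    // saturation to int8 and clamp to [output_min, output_max].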
    const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->rndnu_neon.right_pre_shift);
    const int32x4_t vmultiplier = vld1q_dup_s32(&params->rndnu_neon.multiplier);
    const int32x4_t vright_post_shift = vld1q_dup_s32(&params->rndnu_neon.right_post_shift);

    vacc0x0123 = vqshlq_s32(vacc0x0123, vright_pre_shift);
    vacc0x4567 = vqshlq_s32(vacc0x4567, vright_pre_shift);
    vacc1x0123 = vqshlq_s32(vacc1x0123, vright_pre_shift);
    vacc1x4567 = vqshlq_s32(vacc1x4567, vright_pre_shift);
    vacc2x0123 = vqshlq_s32(vacc2x0123, vright_pre_shift);
    vacc2x4567 = vqshlq_s32(vacc2x4567, vright_pre_shift);
    vacc3x0123 = vqshlq_s32(vacc3x0123, vright_pre_shift);
    vacc3x4567 = vqshlq_s32(vacc3x4567, vright_pre_shift);

    vacc0x0123 = vqdmulhq_s32(vacc0x0123, vmultiplier);
    vacc0x4567 = vqdmulhq_s32(vacc0x4567, vmultiplier);
    vacc1x0123 = vqdmulhq_s32(vacc1x0123, vmultiplier);
    vacc1x4567 = vqdmulhq_s32(vacc1x4567, vmultiplier);
    vacc2x0123 = vqdmulhq_s32(vacc2x0123, vmultiplier);
    vacc2x4567 = vqdmulhq_s32(vacc2x4567, vmultiplier);
    vacc3x0123 = vqdmulhq_s32(vacc3x0123, vmultiplier);
    vacc3x4567 = vqdmulhq_s32(vacc3x4567, vmultiplier);

    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_post_shift);
    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_post_shift);
    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_post_shift);
    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_post_shift);
    vacc2x0123 = vrshlq_s32(vacc2x0123, vright_post_shift);
    vacc2x4567 = vrshlq_s32(vacc2x4567, vright_post_shift);
    vacc3x0123 = vrshlq_s32(vacc3x0123, vright_post_shift);
    vacc3x4567 = vrshlq_s32(vacc3x4567, vright_post_shift);

    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->rndnu_neon.output_zero_point);
#if XNN_ARCH_ARM64
    int16x8_t vacc0x01234567 = vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567);
    int16x8_t vacc1x01234567 = vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567);
    int16x8_t vacc2x01234567 = vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567);
    int16x8_t vacc3x01234567 = vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567);

    vacc0x01234567 = vqaddq_s16(vacc0x01234567, voutput_zero_point);
    vacc1x01234567 = vqaddq_s16(vacc1x01234567, voutput_zero_point);
    vacc2x01234567 = vqaddq_s16(vacc2x01234567, voutput_zero_point);
    vacc3x01234567 = vqaddq_s16(vacc3x01234567, voutput_zero_point);

    int8x16_t vout0x01234567_1x01234567 = vqmovn_high_s16(vqmovn_s16(vacc0x01234567), vacc1x01234567);
    int8x16_t vout2x01234567_3x01234567 = vqmovn_high_s16(vqmovn_s16(vacc2x01234567), vacc3x01234567);
#else
    int16x8_t vacc0x01234567 = vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567));
    int16x8_t vacc1x01234567 = vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567));
    int16x8_t vacc2x01234567 = vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567));
    int16x8_t vacc3x01234567 = vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567));

    vacc0x01234567 = vqaddq_s16(vacc0x01234567, voutput_zero_point);
    vacc1x01234567 = vqaddq_s16(vacc1x01234567, voutput_zero_point);
    vacc2x01234567 = vqaddq_s16(vacc2x01234567, voutput_zero_point);
    vacc3x01234567 = vqaddq_s16(vacc3x01234567, voutput_zero_point);

    int8x16_t vout0x01234567_1x01234567 = vcombine_s8(vqmovn_s16(vacc0x01234567), vqmovn_s16(vacc1x01234567));
    int8x16_t vout2x01234567_3x01234567 = vcombine_s8(vqmovn_s16(vacc2x01234567), vqmovn_s16(vacc3x01234567));
#endif

    const int8x16_t voutput_min = vld1q_dup_s8(&params->rndnu_neon.output_min);
    vout0x01234567_1x01234567 = vmaxq_s8(vout0x01234567_1x01234567, voutput_min);
    vout2x01234567_3x01234567 = vmaxq_s8(vout2x01234567_3x01234567, voutput_min);

    const int8x16_t voutput_max = vld1q_dup_s8(&params->rndnu_neon.output_max);
    vout0x01234567_1x01234567 = vminq_s8(vout0x01234567_1x01234567, voutput_max);
    vout2x01234567_3x01234567 = vminq_s8(vout2x01234567_3x01234567, voutput_max);

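    // Store the 4x8 output tile: all 8 columns when nc >= 8, otherwise a 4/2/1-column tail.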
    if (nc >= 8) {
      vst1_s8(c0 + 0, vget_low_s8(vout0x01234567_1x01234567));
      vst1_s8(c1 + 0, vget_high_s8(vout0x01234567_1x01234567));
      vst1_s8(c2 + 0, vget_low_s8(vout2x01234567_3x01234567));
      vst1_s8(c3 + 0, vget_high_s8(vout2x01234567_3x01234567));

      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
      c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
      c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
      c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);

      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
      a1 = (const int8_t*) ((uintptr_t) a1 - kc);
      a2 = (const int8_t*) ((uintptr_t) a2 - kc);
      a3 = (const int8_t*) ((uintptr_t) a3 - kc);

      nc -= 8;
    } else {
      // Final case where not all of the 8 columns fit in the destination.
      if (nc & 4) {
        vst1q_lane_u32((void*) c0, vreinterpretq_u32_s8(vout0x01234567_1x01234567), 0); c0 += 4;
        vst1q_lane_u32((void*) c1, vreinterpretq_u32_s8(vout0x01234567_1x01234567), 2); c1 += 4;
        vst1q_lane_u32((void*) c2, vreinterpretq_u32_s8(vout2x01234567_3x01234567), 0); c2 += 4;
        vst1q_lane_u32((void*) c3, vreinterpretq_u32_s8(vout2x01234567_3x01234567), 2); c3 += 4;
        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
        vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
      }
      if (nc & 2) {
        vst1q_lane_u16((void*) c0, vreinterpretq_u16_s8(vout0x01234567_1x01234567), 0); c0 += 2;
        vst1q_lane_u16((void*) c1, vreinterpretq_u16_s8(vout0x01234567_1x01234567), 4); c1 += 2;
        vst1q_lane_u16((void*) c2, vreinterpretq_u16_s8(vout2x01234567_3x01234567), 0); c2 += 2;
        vst1q_lane_u16((void*) c3, vreinterpretq_u16_s8(vout2x01234567_3x01234567), 4); c3 += 2;
        vout0x01234567_1x01234567 = vextq_s8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
        vout2x01234567_3x01234567 = vextq_s8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
      }
      if (nc & 1) {
        vst1q_lane_s8(c0, vout0x01234567_1x01234567, 0);
        vst1q_lane_s8(c1, vout0x01234567_1x01234567, 8);
        vst1q_lane_s8(c2, vout2x01234567_3x01234567, 0);
        vst1q_lane_s8(c3, vout2x01234567_3x01234567, 8);
      }

      nc = 0;
    }
  } while (nc != 0);
}