// Auto-generated file. Do not edit!
// Template: src/f32-spmm/neon-blocked.c.in
// Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

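// Sparse-matrix (weights) times dense-matrix (input) multiplication with
// min/max clamping, for AArch64 NEON with FMA. The sparse weights use a
// compressed encoding: nidx_nnzmap gives the number of non-zero weights per
// output channel (per block of 4 channels in the blocked path), and
// widx_dmap gives byte-offset deltas that step the input pointer from one
// used input row to the next. The M dimension (mc, in bytes) is tiled by
// 32 floats, the N dimension (nc) by 4 output channels.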
void xnn_f32_spmm_minmax_ukernel_32x4__neonfma(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t output_decrement = output_stride * nc - 32 * sizeof(float);
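  // Main loop: process the M dimension in tiles of 32 floats. The output
  // pointer advances by output_stride per channel; after all nc channels,
  // output_decrement rewinds it to the start of the next 32-float tile.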
  while XNN_LIKELY(mc >= 32 * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
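    // Blocked loop: handle 4 output channels per iteration. Each accumulator
    // group starts from its channel's first packed weight value (the bias).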
    while (n >= 4) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n0 = vacc0123n0;
      float32x4_t vacc89ABn0 = vacc0123n0;
      float32x4_t vaccCDEFn0 = vacc0123n0;
      float32x4_t vaccGHIJn0 = vacc0123n0;
      float32x4_t vaccKLMNn0 = vacc0123n0;
      float32x4_t vaccOPQRn0 = vacc0123n0;
      float32x4_t vaccSTUVn0 = vacc0123n0;
      float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n1 = vacc0123n1;
      float32x4_t vacc89ABn1 = vacc0123n1;
      float32x4_t vaccCDEFn1 = vacc0123n1;
      float32x4_t vaccGHIJn1 = vacc0123n1;
      float32x4_t vaccKLMNn1 = vacc0123n1;
      float32x4_t vaccOPQRn1 = vacc0123n1;
      float32x4_t vaccSTUVn1 = vacc0123n1;
      float32x4_t vacc0123n2 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n2 = vacc0123n2;
      float32x4_t vacc89ABn2 = vacc0123n2;
      float32x4_t vaccCDEFn2 = vacc0123n2;
      float32x4_t vaccGHIJn2 = vacc0123n2;
      float32x4_t vaccKLMNn2 = vacc0123n2;
      float32x4_t vaccOPQRn2 = vacc0123n2;
      float32x4_t vaccSTUVn2 = vacc0123n2;
      float32x4_t vacc0123n3 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n3 = vacc0123n3;
      float32x4_t vacc89ABn3 = vacc0123n3;
      float32x4_t vaccCDEFn3 = vacc0123n3;
      float32x4_t vaccGHIJn3 = vacc0123n3;
      float32x4_t vaccKLMNn3 = vacc0123n3;
      float32x4_t vaccOPQRn3 = vacc0123n3;
      float32x4_t vaccSTUVn3 = vacc0123n3;
      if XNN_LIKELY(nnz != 0) {
        do {
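          // Each widx_dmap entry is a byte-offset delta that advances the
          // input pointer to the next input row with non-zero weights.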
          const intptr_t diff = *dmap++;
          const float32x4_t vi0123 = vld1q_f32(input);
          const float32x4_t vi4567 = vld1q_f32(input + 4);
          const float32x4_t vi89AB = vld1q_f32(input + 8);
          const float32x4_t viCDEF = vld1q_f32(input + 12);
          const float32x4_t viGHIJ = vld1q_f32(input + 16);
          const float32x4_t viKLMN = vld1q_f32(input + 20);
          const float32x4_t viOPQR = vld1q_f32(input + 24);
          const float32x4_t viSTUV = vld1q_f32(input + 28);
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
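          // Software prefetches for upcoming input rows and weights, to hide
          // memory latency behind the FMAs.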
          __builtin_prefetch(input + 16);
          __builtin_prefetch(input + 32);
          const float32x4_t vw = vld1q_f32(w); w += 4;
          __builtin_prefetch(w + 32);
          vacc0123n0 = vfmaq_laneq_f32(vacc0123n0, vi0123, vw, 0);
          vacc4567n0 = vfmaq_laneq_f32(vacc4567n0, vi4567, vw, 0);
          vacc89ABn0 = vfmaq_laneq_f32(vacc89ABn0, vi89AB, vw, 0);
          vaccCDEFn0 = vfmaq_laneq_f32(vaccCDEFn0, viCDEF, vw, 0);
          vaccGHIJn0 = vfmaq_laneq_f32(vaccGHIJn0, viGHIJ, vw, 0);
          vaccKLMNn0 = vfmaq_laneq_f32(vaccKLMNn0, viKLMN, vw, 0);
          vaccOPQRn0 = vfmaq_laneq_f32(vaccOPQRn0, viOPQR, vw, 0);
          vaccSTUVn0 = vfmaq_laneq_f32(vaccSTUVn0, viSTUV, vw, 0);
          vacc0123n1 = vfmaq_laneq_f32(vacc0123n1, vi0123, vw, 1);
          vacc4567n1 = vfmaq_laneq_f32(vacc4567n1, vi4567, vw, 1);
          vacc89ABn1 = vfmaq_laneq_f32(vacc89ABn1, vi89AB, vw, 1);
          vaccCDEFn1 = vfmaq_laneq_f32(vaccCDEFn1, viCDEF, vw, 1);
          vaccGHIJn1 = vfmaq_laneq_f32(vaccGHIJn1, viGHIJ, vw, 1);
          vaccKLMNn1 = vfmaq_laneq_f32(vaccKLMNn1, viKLMN, vw, 1);
          vaccOPQRn1 = vfmaq_laneq_f32(vaccOPQRn1, viOPQR, vw, 1);
          vaccSTUVn1 = vfmaq_laneq_f32(vaccSTUVn1, viSTUV, vw, 1);
          vacc0123n2 = vfmaq_laneq_f32(vacc0123n2, vi0123, vw, 2);
          vacc4567n2 = vfmaq_laneq_f32(vacc4567n2, vi4567, vw, 2);
          vacc89ABn2 = vfmaq_laneq_f32(vacc89ABn2, vi89AB, vw, 2);
          vaccCDEFn2 = vfmaq_laneq_f32(vaccCDEFn2, viCDEF, vw, 2);
          vaccGHIJn2 = vfmaq_laneq_f32(vaccGHIJn2, viGHIJ, vw, 2);
          vaccKLMNn2 = vfmaq_laneq_f32(vaccKLMNn2, viKLMN, vw, 2);
          vaccOPQRn2 = vfmaq_laneq_f32(vaccOPQRn2, viOPQR, vw, 2);
          vaccSTUVn2 = vfmaq_laneq_f32(vaccSTUVn2, viSTUV, vw, 2);
          vacc0123n3 = vfmaq_laneq_f32(vacc0123n3, vi0123, vw, 3);
          vacc4567n3 = vfmaq_laneq_f32(vacc4567n3, vi4567, vw, 3);
          vacc89ABn3 = vfmaq_laneq_f32(vacc89ABn3, vi89AB, vw, 3);
          vaccCDEFn3 = vfmaq_laneq_f32(vaccCDEFn3, viCDEF, vw, 3);
          vaccGHIJn3 = vfmaq_laneq_f32(vaccGHIJn3, viGHIJ, vw, 3);
          vaccKLMNn3 = vfmaq_laneq_f32(vaccKLMNn3, viKLMN, vw, 3);
          vaccOPQRn3 = vfmaq_laneq_f32(vaccOPQRn3, viOPQR, vw, 3);
          vaccSTUVn3 = vfmaq_laneq_f32(vaccSTUVn3, viSTUV, vw, 3);
        } while (--nnz != 0);
      }
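      // Clamp the 32x4 tile of accumulators to the [min, max] output range.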
      float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
      float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
      float32x4_t vout89ABn0 = vminq_f32(vacc89ABn0, vmax);
      float32x4_t voutCDEFn0 = vminq_f32(vaccCDEFn0, vmax);
      float32x4_t voutGHIJn0 = vminq_f32(vaccGHIJn0, vmax);
      float32x4_t voutKLMNn0 = vminq_f32(vaccKLMNn0, vmax);
      float32x4_t voutOPQRn0 = vminq_f32(vaccOPQRn0, vmax);
      float32x4_t voutSTUVn0 = vminq_f32(vaccSTUVn0, vmax);
      float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
      float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
      float32x4_t vout89ABn1 = vminq_f32(vacc89ABn1, vmax);
      float32x4_t voutCDEFn1 = vminq_f32(vaccCDEFn1, vmax);
      float32x4_t voutGHIJn1 = vminq_f32(vaccGHIJn1, vmax);
      float32x4_t voutKLMNn1 = vminq_f32(vaccKLMNn1, vmax);
      float32x4_t voutOPQRn1 = vminq_f32(vaccOPQRn1, vmax);
      float32x4_t voutSTUVn1 = vminq_f32(vaccSTUVn1, vmax);
      float32x4_t vout0123n2 = vminq_f32(vacc0123n2, vmax);
      float32x4_t vout4567n2 = vminq_f32(vacc4567n2, vmax);
      float32x4_t vout89ABn2 = vminq_f32(vacc89ABn2, vmax);
      float32x4_t voutCDEFn2 = vminq_f32(vaccCDEFn2, vmax);
      float32x4_t voutGHIJn2 = vminq_f32(vaccGHIJn2, vmax);
      float32x4_t voutKLMNn2 = vminq_f32(vaccKLMNn2, vmax);
      float32x4_t voutOPQRn2 = vminq_f32(vaccOPQRn2, vmax);
      float32x4_t voutSTUVn2 = vminq_f32(vaccSTUVn2, vmax);
      float32x4_t vout0123n3 = vminq_f32(vacc0123n3, vmax);
      float32x4_t vout4567n3 = vminq_f32(vacc4567n3, vmax);
      float32x4_t vout89ABn3 = vminq_f32(vacc89ABn3, vmax);
      float32x4_t voutCDEFn3 = vminq_f32(vaccCDEFn3, vmax);
      float32x4_t voutGHIJn3 = vminq_f32(vaccGHIJn3, vmax);
      float32x4_t voutKLMNn3 = vminq_f32(vaccKLMNn3, vmax);
      float32x4_t voutOPQRn3 = vminq_f32(vaccOPQRn3, vmax);
      float32x4_t voutSTUVn3 = vminq_f32(vaccSTUVn3, vmax);

      vout0123n0 = vmaxq_f32(vout0123n0, vmin);
      vout4567n0 = vmaxq_f32(vout4567n0, vmin);
      vout89ABn0 = vmaxq_f32(vout89ABn0, vmin);
      voutCDEFn0 = vmaxq_f32(voutCDEFn0, vmin);
      voutGHIJn0 = vmaxq_f32(voutGHIJn0, vmin);
      voutKLMNn0 = vmaxq_f32(voutKLMNn0, vmin);
      voutOPQRn0 = vmaxq_f32(voutOPQRn0, vmin);
      voutSTUVn0 = vmaxq_f32(voutSTUVn0, vmin);
      vout0123n1 = vmaxq_f32(vout0123n1, vmin);
      vout4567n1 = vmaxq_f32(vout4567n1, vmin);
      vout89ABn1 = vmaxq_f32(vout89ABn1, vmin);
      voutCDEFn1 = vmaxq_f32(voutCDEFn1, vmin);
      voutGHIJn1 = vmaxq_f32(voutGHIJn1, vmin);
      voutKLMNn1 = vmaxq_f32(voutKLMNn1, vmin);
      voutOPQRn1 = vmaxq_f32(voutOPQRn1, vmin);
      voutSTUVn1 = vmaxq_f32(voutSTUVn1, vmin);
      vout0123n2 = vmaxq_f32(vout0123n2, vmin);
      vout4567n2 = vmaxq_f32(vout4567n2, vmin);
      vout89ABn2 = vmaxq_f32(vout89ABn2, vmin);
      voutCDEFn2 = vmaxq_f32(voutCDEFn2, vmin);
      voutGHIJn2 = vmaxq_f32(voutGHIJn2, vmin);
      voutKLMNn2 = vmaxq_f32(voutKLMNn2, vmin);
      voutOPQRn2 = vmaxq_f32(voutOPQRn2, vmin);
      voutSTUVn2 = vmaxq_f32(voutSTUVn2, vmin);
      vout0123n3 = vmaxq_f32(vout0123n3, vmin);
      vout4567n3 = vmaxq_f32(vout4567n3, vmin);
      vout89ABn3 = vmaxq_f32(vout89ABn3, vmin);
      voutCDEFn3 = vmaxq_f32(voutCDEFn3, vmin);
      voutGHIJn3 = vmaxq_f32(voutGHIJn3, vmin);
      voutKLMNn3 = vmaxq_f32(voutKLMNn3, vmin);
      voutOPQRn3 = vmaxq_f32(voutOPQRn3, vmin);
      voutSTUVn3 = vmaxq_f32(voutSTUVn3, vmin);

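      // Store the four output-channel stripes, each output_stride bytes apart.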
      vst1q_f32(output + 0, vout0123n0);
      vst1q_f32(output + 4, vout4567n0);
      vst1q_f32(output + 8, vout89ABn0);
      vst1q_f32(output + 12, voutCDEFn0);
      vst1q_f32(output + 16, voutGHIJn0);
      vst1q_f32(output + 20, voutKLMNn0);
      vst1q_f32(output + 24, voutOPQRn0);
      vst1q_f32(output + 28, voutSTUVn0);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n1);
      vst1q_f32(output + 4, vout4567n1);
      vst1q_f32(output + 8, vout89ABn1);
      vst1q_f32(output + 12, voutCDEFn1);
      vst1q_f32(output + 16, voutGHIJn1);
      vst1q_f32(output + 20, voutKLMNn1);
      vst1q_f32(output + 24, voutOPQRn1);
      vst1q_f32(output + 28, voutSTUVn1);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n2);
      vst1q_f32(output + 4, vout4567n2);
      vst1q_f32(output + 8, vout89ABn2);
      vst1q_f32(output + 12, voutCDEFn2);
      vst1q_f32(output + 16, voutGHIJn2);
      vst1q_f32(output + 20, voutKLMNn2);
      vst1q_f32(output + 24, voutOPQRn2);
      vst1q_f32(output + 28, voutSTUVn2);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n3);
      vst1q_f32(output + 4, vout4567n3);
      vst1q_f32(output + 8, vout89ABn3);
      vst1q_f32(output + 12, voutCDEFn3);
      vst1q_f32(output + 16, voutGHIJn3);
      vst1q_f32(output + 20, voutKLMNn3);
      vst1q_f32(output + 24, voutOPQRn3);
      vst1q_f32(output + 28, voutSTUVn3);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      n -= 4;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(n != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        float32x4_t vacc89AB = vacc0123;
        float32x4_t vaccCDEF = vacc0123;
        float32x4_t vaccGHIJ = vacc0123;
        float32x4_t vaccKLMN = vacc0123;
        float32x4_t vaccOPQR = vacc0123;
        float32x4_t vaccSTUV = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            const float32x4_t vi89AB = vld1q_f32(input + 8);
            const float32x4_t viCDEF = vld1q_f32(input + 12);
            const float32x4_t viGHIJ = vld1q_f32(input + 16);
            const float32x4_t viKLMN = vld1q_f32(input + 20);
            const float32x4_t viOPQR = vld1q_f32(input + 24);
            const float32x4_t viSTUV = vld1q_f32(input + 28);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            __builtin_prefetch(input + 32);
            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
            vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
            vacc89AB = vfmaq_f32(vacc89AB, vi89AB, vw);
            vaccCDEF = vfmaq_f32(vaccCDEF, viCDEF, vw);
            vaccGHIJ = vfmaq_f32(vaccGHIJ, viGHIJ, vw);
            vaccKLMN = vfmaq_f32(vaccKLMN, viKLMN, vw);
            vaccOPQR = vfmaq_f32(vaccOPQR, viOPQR, vw);
            vaccSTUV = vfmaq_f32(vaccSTUV, viSTUV, vw);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
        float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
        float32x4_t voutGHIJ = vminq_f32(vaccGHIJ, vmax);
        float32x4_t voutKLMN = vminq_f32(vaccKLMN, vmax);
        float32x4_t voutOPQR = vminq_f32(vaccOPQR, vmax);
        float32x4_t voutSTUV = vminq_f32(vaccSTUV, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vout89AB = vmaxq_f32(vout89AB, vmin);
        voutCDEF = vmaxq_f32(voutCDEF, vmin);
        voutGHIJ = vmaxq_f32(voutGHIJ, vmin);
        voutKLMN = vmaxq_f32(voutKLMN, vmin);
        voutOPQR = vmaxq_f32(voutOPQR, vmin);
        voutSTUV = vmaxq_f32(voutSTUV, vmin);

        vst1q_f32(output + 0, vout0123);
        vst1q_f32(output + 4, vout4567);
        vst1q_f32(output + 8, vout89AB);
        vst1q_f32(output + 12, voutCDEF);
        vst1q_f32(output + 16, voutGHIJ);
        vst1q_f32(output + 20, voutKLMN);
        vst1q_f32(output + 24, voutOPQR);
        vst1q_f32(output + 28, voutSTUV);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 1;
      } while (n != 0);
    }
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += 32;
    mc -= 32 * sizeof(float);
  }
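  // Remainder: handle up to 31 leftover M elements with progressively
  // narrower sub-tiles of 16, 8, 4, 2, and 1 floats.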
  if XNN_UNLIKELY(mc != 0) {
    output_decrement += 16 * sizeof(float);
    if (mc & (16 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n0 = vacc0123n0;
        float32x4_t vacc89ABn0 = vacc0123n0;
        float32x4_t vaccCDEFn0 = vacc0123n0;
        float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n1 = vacc0123n1;
        float32x4_t vacc89ABn1 = vacc0123n1;
        float32x4_t vaccCDEFn1 = vacc0123n1;
        float32x4_t vacc0123n2 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n2 = vacc0123n2;
        float32x4_t vacc89ABn2 = vacc0123n2;
        float32x4_t vaccCDEFn2 = vacc0123n2;
        float32x4_t vacc0123n3 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n3 = vacc0123n3;
        float32x4_t vacc89ABn3 = vacc0123n3;
        float32x4_t vaccCDEFn3 = vacc0123n3;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            const float32x4_t vi89AB = vld1q_f32(input + 8);
            const float32x4_t viCDEF = vld1q_f32(input + 12);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc0123n0 = vfmaq_laneq_f32(vacc0123n0, vi0123, vw, 0);
            vacc4567n0 = vfmaq_laneq_f32(vacc4567n0, vi4567, vw, 0);
            vacc89ABn0 = vfmaq_laneq_f32(vacc89ABn0, vi89AB, vw, 0);
            vaccCDEFn0 = vfmaq_laneq_f32(vaccCDEFn0, viCDEF, vw, 0);
            vacc0123n1 = vfmaq_laneq_f32(vacc0123n1, vi0123, vw, 1);
            vacc4567n1 = vfmaq_laneq_f32(vacc4567n1, vi4567, vw, 1);
            vacc89ABn1 = vfmaq_laneq_f32(vacc89ABn1, vi89AB, vw, 1);
            vaccCDEFn1 = vfmaq_laneq_f32(vaccCDEFn1, viCDEF, vw, 1);
            vacc0123n2 = vfmaq_laneq_f32(vacc0123n2, vi0123, vw, 2);
            vacc4567n2 = vfmaq_laneq_f32(vacc4567n2, vi4567, vw, 2);
            vacc89ABn2 = vfmaq_laneq_f32(vacc89ABn2, vi89AB, vw, 2);
            vaccCDEFn2 = vfmaq_laneq_f32(vaccCDEFn2, viCDEF, vw, 2);
            vacc0123n3 = vfmaq_laneq_f32(vacc0123n3, vi0123, vw, 3);
            vacc4567n3 = vfmaq_laneq_f32(vacc4567n3, vi4567, vw, 3);
            vacc89ABn3 = vfmaq_laneq_f32(vacc89ABn3, vi89AB, vw, 3);
            vaccCDEFn3 = vfmaq_laneq_f32(vaccCDEFn3, viCDEF, vw, 3);
          } while (--nnz != 0);
        }
        float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
        float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
        float32x4_t vout89ABn0 = vminq_f32(vacc89ABn0, vmax);
        float32x4_t voutCDEFn0 = vminq_f32(vaccCDEFn0, vmax);
        float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
        float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
        float32x4_t vout89ABn1 = vminq_f32(vacc89ABn1, vmax);
        float32x4_t voutCDEFn1 = vminq_f32(vaccCDEFn1, vmax);
        float32x4_t vout0123n2 = vminq_f32(vacc0123n2, vmax);
        float32x4_t vout4567n2 = vminq_f32(vacc4567n2, vmax);
        float32x4_t vout89ABn2 = vminq_f32(vacc89ABn2, vmax);
        float32x4_t voutCDEFn2 = vminq_f32(vaccCDEFn2, vmax);
        float32x4_t vout0123n3 = vminq_f32(vacc0123n3, vmax);
        float32x4_t vout4567n3 = vminq_f32(vacc4567n3, vmax);
        float32x4_t vout89ABn3 = vminq_f32(vacc89ABn3, vmax);
        float32x4_t voutCDEFn3 = vminq_f32(vaccCDEFn3, vmax);

        vout0123n0 = vmaxq_f32(vout0123n0, vmin);
        vout4567n0 = vmaxq_f32(vout4567n0, vmin);
        vout89ABn0 = vmaxq_f32(vout89ABn0, vmin);
        voutCDEFn0 = vmaxq_f32(voutCDEFn0, vmin);
        vout0123n1 = vmaxq_f32(vout0123n1, vmin);
        vout4567n1 = vmaxq_f32(vout4567n1, vmin);
        vout89ABn1 = vmaxq_f32(vout89ABn1, vmin);
        voutCDEFn1 = vmaxq_f32(voutCDEFn1, vmin);
        vout0123n2 = vmaxq_f32(vout0123n2, vmin);
        vout4567n2 = vmaxq_f32(vout4567n2, vmin);
        vout89ABn2 = vmaxq_f32(vout89ABn2, vmin);
        voutCDEFn2 = vmaxq_f32(voutCDEFn2, vmin);
        vout0123n3 = vmaxq_f32(vout0123n3, vmin);
        vout4567n3 = vmaxq_f32(vout4567n3, vmin);
        vout89ABn3 = vmaxq_f32(vout89ABn3, vmin);
        voutCDEFn3 = vmaxq_f32(voutCDEFn3, vmin);

        vst1q_f32(output + 0, vout0123n0);
        vst1q_f32(output + 4, vout4567n0);
        vst1q_f32(output + 8, vout89ABn0);
        vst1q_f32(output + 12, voutCDEFn0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n1);
        vst1q_f32(output + 4, vout4567n1);
        vst1q_f32(output + 8, vout89ABn1);
        vst1q_f32(output + 12, voutCDEFn1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n2);
        vst1q_f32(output + 4, vout4567n2);
        vst1q_f32(output + 8, vout89ABn2);
        vst1q_f32(output + 12, voutCDEFn2);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n3);
        vst1q_f32(output + 4, vout4567n3);
        vst1q_f32(output + 8, vout89ABn3);
        vst1q_f32(output + 12, voutCDEFn3);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          float32x4_t vacc4567 = vacc0123;
          float32x4_t vacc89AB = vacc0123;
          float32x4_t vaccCDEF = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t vi0123 = vld1q_f32(input);
              const float32x4_t vi4567 = vld1q_f32(input + 4);
              const float32x4_t vi89AB = vld1q_f32(input + 8);
              const float32x4_t viCDEF = vld1q_f32(input + 12);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
              vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
              vacc89AB = vfmaq_f32(vacc89AB, vi89AB, vw);
              vaccCDEF = vfmaq_f32(vaccCDEF, viCDEF, vw);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
          float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
          float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);
          vout4567 = vmaxq_f32(vout4567, vmin);
          vout89AB = vmaxq_f32(vout89AB, vmin);
          voutCDEF = vmaxq_f32(voutCDEF, vmin);

          vst1q_f32(output + 0, vout0123);
          vst1q_f32(output + 4, vout4567);
          vst1q_f32(output + 8, vout89AB);
          vst1q_f32(output + 12, voutCDEF);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 16;
    }
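    // Same pattern with an 8-float M sub-tile.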
    output_decrement += 8 * sizeof(float);
    if (mc & (8 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n0 = vacc0123n0;
        float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n1 = vacc0123n1;
        float32x4_t vacc0123n2 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n2 = vacc0123n2;
        float32x4_t vacc0123n3 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n3 = vacc0123n3;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc0123n0 = vfmaq_laneq_f32(vacc0123n0, vi0123, vw, 0);
            vacc4567n0 = vfmaq_laneq_f32(vacc4567n0, vi4567, vw, 0);
            vacc0123n1 = vfmaq_laneq_f32(vacc0123n1, vi0123, vw, 1);
            vacc4567n1 = vfmaq_laneq_f32(vacc4567n1, vi4567, vw, 1);
            vacc0123n2 = vfmaq_laneq_f32(vacc0123n2, vi0123, vw, 2);
            vacc4567n2 = vfmaq_laneq_f32(vacc4567n2, vi4567, vw, 2);
            vacc0123n3 = vfmaq_laneq_f32(vacc0123n3, vi0123, vw, 3);
            vacc4567n3 = vfmaq_laneq_f32(vacc4567n3, vi4567, vw, 3);
          } while (--nnz != 0);
        }
        float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
        float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
        float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
        float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
        float32x4_t vout0123n2 = vminq_f32(vacc0123n2, vmax);
        float32x4_t vout4567n2 = vminq_f32(vacc4567n2, vmax);
        float32x4_t vout0123n3 = vminq_f32(vacc0123n3, vmax);
        float32x4_t vout4567n3 = vminq_f32(vacc4567n3, vmax);

        vout0123n0 = vmaxq_f32(vout0123n0, vmin);
        vout4567n0 = vmaxq_f32(vout4567n0, vmin);
        vout0123n1 = vmaxq_f32(vout0123n1, vmin);
        vout4567n1 = vmaxq_f32(vout4567n1, vmin);
        vout0123n2 = vmaxq_f32(vout0123n2, vmin);
        vout4567n2 = vmaxq_f32(vout4567n2, vmin);
        vout0123n3 = vmaxq_f32(vout0123n3, vmin);
        vout4567n3 = vmaxq_f32(vout4567n3, vmin);

        vst1q_f32(output + 0, vout0123n0);
        vst1q_f32(output + 4, vout4567n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n1);
        vst1q_f32(output + 4, vout4567n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n2);
        vst1q_f32(output + 4, vout4567n2);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n3);
        vst1q_f32(output + 4, vout4567n3);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          float32x4_t vacc4567 = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t vi0123 = vld1q_f32(input);
              const float32x4_t vi4567 = vld1q_f32(input + 4);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
              vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);
          vout4567 = vmaxq_f32(vout4567, vmin);

          vst1q_f32(output + 0, vout0123);
          vst1q_f32(output + 4, vout4567);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 8;
    }
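    // 4-float M sub-tile.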
    output_decrement += 4 * sizeof(float);
    if (mc & (4 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123n2 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123n3 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc0123n0 = vfmaq_laneq_f32(vacc0123n0, vi0123, vw, 0);
            vacc0123n1 = vfmaq_laneq_f32(vacc0123n1, vi0123, vw, 1);
            vacc0123n2 = vfmaq_laneq_f32(vacc0123n2, vi0123, vw, 2);
            vacc0123n3 = vfmaq_laneq_f32(vacc0123n3, vi0123, vw, 3);
          } while (--nnz != 0);
        }
        float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
        float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
        float32x4_t vout0123n2 = vminq_f32(vacc0123n2, vmax);
        float32x4_t vout0123n3 = vminq_f32(vacc0123n3, vmax);

        vout0123n0 = vmaxq_f32(vout0123n0, vmin);
        vout0123n1 = vmaxq_f32(vout0123n1, vmin);
        vout0123n2 = vmaxq_f32(vout0123n2, vmin);
        vout0123n3 = vmaxq_f32(vout0123n3, vmin);

        vst1q_f32(output + 0, vout0123n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n2);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n3);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t vi0123 = vld1q_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(output + 0, vout0123);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 4;
    }
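    // 2-float M sub-tile, using 64-bit NEON registers (float32x2_t).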
    output_decrement += 2 * sizeof(float);
    if (mc & (2 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi01 = vld1_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc01n0 = vfma_laneq_f32(vacc01n0, vi01, vw, 0);
            vacc01n1 = vfma_laneq_f32(vacc01n1, vi01, vw, 1);
            vacc01n2 = vfma_laneq_f32(vacc01n2, vi01, vw, 2);
            vacc01n3 = vfma_laneq_f32(vacc01n3, vi01, vw, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout01n0 = vmin_f32(vacc01n0, vget_low_f32(vmax));
        float32x2_t vout01n1 = vmin_f32(vacc01n1, vget_low_f32(vmax));
        float32x2_t vout01n2 = vmin_f32(vacc01n2, vget_low_f32(vmax));
        float32x2_t vout01n3 = vmin_f32(vacc01n3, vget_low_f32(vmax));

        vout01n0 = vmax_f32(vout01n0, vget_low_f32(vmin));
        vout01n1 = vmax_f32(vout01n1, vget_low_f32(vmin));
        vout01n2 = vmax_f32(vout01n2, vget_low_f32(vmin));
        vout01n3 = vmax_f32(vout01n3, vget_low_f32(vmin));

        vst1_f32(output + 0, vout01n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n2);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n3);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi01 = vld1_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, vi01, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(output, vout01);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 2;
    }
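    // 1-float M sub-tile, loading and storing single lanes.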
    output_decrement += 1 * sizeof(float);
    if (mc & (1 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi0 = vld1_dup_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc0n0 = vfma_laneq_f32(vacc0n0, vi0, vw, 0);
            vacc0n1 = vfma_laneq_f32(vacc0n1, vi0, vw, 1);
            vacc0n2 = vfma_laneq_f32(vacc0n2, vi0, vw, 2);
            vacc0n3 = vfma_laneq_f32(vacc0n3, vi0, vw, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout0n0 = vmin_f32(vacc0n0, vget_low_f32(vmax));
        float32x2_t vout0n1 = vmin_f32(vacc0n1, vget_low_f32(vmax));
        float32x2_t vout0n2 = vmin_f32(vacc0n2, vget_low_f32(vmax));
        float32x2_t vout0n3 = vmin_f32(vacc0n3, vget_low_f32(vmax));

        vout0n0 = vmax_f32(vout0n0, vget_low_f32(vmin));
        vout0n1 = vmax_f32(vout0n1, vget_low_f32(vmin));
        vout0n2 = vmax_f32(vout0n2, vget_low_f32(vmin));
        vout0n3 = vmax_f32(vout0n3, vget_low_f32(vmin));

        vst1_lane_f32(output + 0, vout0n0, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n1, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n2, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n3, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi0 = vld1_dup_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, vi0, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          vst1_lane_f32(output, vout0, 0);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 1;
    }
  }
}
