// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

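// Sparse-weights x dense-input matrix multiplication with min/max clamping.
// The kernel walks mc (dense input elements, counted in bytes) in blocks of
// MR=8 floats, and nc output channels in blocks of NR=4. For each 4-channel
// block it loads 4 per-channel initializers (the packed bias values) from w,
// then nnz groups of 4 weights. widx_dmap supplies the signed byte offset
// that advances the input pointer to the next nonzero block's input row;
// nidx_nnzmap supplies the nonzero count per output channel (per 4-channel
// block in the blocked main loop).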
void xnn_f32_spmm_minmax_ukernel_8x4__neonfma(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
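  // After a pass over all nc output channels the output pointer has moved
  // output_stride * nc bytes; subtracting output_decrement rewinds it to
  // the start of the next block of 8 mc elements.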
  size_t output_decrement = output_stride * nc - 8 * sizeof(float);
  while XNN_LIKELY(mc >= 8 * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
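    // Blocked loop over output channels, NR=4 at a time: the four channels
    // of a block share one nonzero count, and each accumulator pair starts
    // from that channel's initializer in w.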
    while (n >= 4) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n0 = vacc0123n0;
      float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n1 = vacc0123n1;
      float32x4_t vacc0123n2 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n2 = vacc0123n2;
      float32x4_t vacc0123n3 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n3 = vacc0123n3;
      if XNN_LIKELY(nnz != 0) {
        do {
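          // Load 8 inputs, then advance the input pointer by the signed
          // byte offset from dmap to the next nonzero block's input row.
          // The block's 4 weights are applied with lane-indexed FMAs;
          // prefetches pull in data for upcoming iterations.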
          const intptr_t diff = *dmap++;
          const float32x4_t vi0123 = vld1q_f32(input);
          const float32x4_t vi4567 = vld1q_f32(input + 4);
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          __builtin_prefetch(input + 16);
          const float32x4_t vw = vld1q_f32(w); w += 4;
          __builtin_prefetch(w + 32);
          vacc0123n0 = vfmaq_laneq_f32(vacc0123n0, vi0123, vw, 0);
          vacc4567n0 = vfmaq_laneq_f32(vacc4567n0, vi4567, vw, 0);
          vacc0123n1 = vfmaq_laneq_f32(vacc0123n1, vi0123, vw, 1);
          vacc4567n1 = vfmaq_laneq_f32(vacc4567n1, vi4567, vw, 1);
          vacc0123n2 = vfmaq_laneq_f32(vacc0123n2, vi0123, vw, 2);
          vacc4567n2 = vfmaq_laneq_f32(vacc4567n2, vi4567, vw, 2);
          vacc0123n3 = vfmaq_laneq_f32(vacc0123n3, vi0123, vw, 3);
          vacc4567n3 = vfmaq_laneq_f32(vacc4567n3, vi4567, vw, 3);
        } while (--nnz != 0);
      }
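      // Clamp the accumulators to [min, max].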
      float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
      float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
      float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
      float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
      float32x4_t vout0123n2 = vminq_f32(vacc0123n2, vmax);
      float32x4_t vout4567n2 = vminq_f32(vacc4567n2, vmax);
      float32x4_t vout0123n3 = vminq_f32(vacc0123n3, vmax);
      float32x4_t vout4567n3 = vminq_f32(vacc4567n3, vmax);

      vout0123n0 = vmaxq_f32(vout0123n0, vmin);
      vout4567n0 = vmaxq_f32(vout4567n0, vmin);
      vout0123n1 = vmaxq_f32(vout0123n1, vmin);
      vout4567n1 = vmaxq_f32(vout4567n1, vmin);
      vout0123n2 = vmaxq_f32(vout0123n2, vmin);
      vout4567n2 = vmaxq_f32(vout4567n2, vmin);
      vout0123n3 = vmaxq_f32(vout0123n3, vmin);
      vout4567n3 = vmaxq_f32(vout4567n3, vmin);

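      // Store 8 outputs per channel; consecutive output channels are
      // output_stride bytes apart.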
      vst1q_f32(output + 0, vout0123n0);
      vst1q_f32(output + 4, vout4567n0);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n1);
      vst1q_f32(output + 4, vout4567n1);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n2);
      vst1q_f32(output + 4, vout4567n2);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n3);
      vst1q_f32(output + 4, vout4567n3);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      n -= 4;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(n != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
            vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);

        vst1q_f32(output + 0, vout0123);
        vst1q_f32(output + 4, vout4567);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 1;
      } while (n != 0);
    }
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += 8;
    mc -= 8 * sizeof(float);
  }
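  // Tail: fewer than 8 mc elements remain. The bit tests peel off blocks of
  // 4, 2, and 1 element(s), mirroring the main loop at narrower widths.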
  if XNN_UNLIKELY(mc != 0) {
    output_decrement += 4 * sizeof(float);
    if (mc & (4 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123n2 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123n3 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc0123n0 = vfmaq_laneq_f32(vacc0123n0, vi0123, vw, 0);
            vacc0123n1 = vfmaq_laneq_f32(vacc0123n1, vi0123, vw, 1);
            vacc0123n2 = vfmaq_laneq_f32(vacc0123n2, vi0123, vw, 2);
            vacc0123n3 = vfmaq_laneq_f32(vacc0123n3, vi0123, vw, 3);
          } while (--nnz != 0);
        }
        float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
        float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
        float32x4_t vout0123n2 = vminq_f32(vacc0123n2, vmax);
        float32x4_t vout0123n3 = vminq_f32(vacc0123n3, vmax);

        vout0123n0 = vmaxq_f32(vout0123n0, vmin);
        vout0123n1 = vmaxq_f32(vout0123n1, vmin);
        vout0123n2 = vmaxq_f32(vout0123n2, vmin);
        vout0123n3 = vmaxq_f32(vout0123n3, vmin);

        vst1q_f32(output + 0, vout0123n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n2);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n3);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t vi0123 = vld1q_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(output + 0, vout0123);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 4;
    }
    output_decrement += 2 * sizeof(float);
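    // 2-element tail: same structure on 64-bit d-registers; the clamp
    // limits come from the low halves of the q-register constants.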
    if (mc & (2 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi01 = vld1_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc01n0 = vfma_laneq_f32(vacc01n0, vi01, vw, 0);
            vacc01n1 = vfma_laneq_f32(vacc01n1, vi01, vw, 1);
            vacc01n2 = vfma_laneq_f32(vacc01n2, vi01, vw, 2);
            vacc01n3 = vfma_laneq_f32(vacc01n3, vi01, vw, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout01n0 = vmin_f32(vacc01n0, vget_low_f32(vmax));
        float32x2_t vout01n1 = vmin_f32(vacc01n1, vget_low_f32(vmax));
        float32x2_t vout01n2 = vmin_f32(vacc01n2, vget_low_f32(vmax));
        float32x2_t vout01n3 = vmin_f32(vacc01n3, vget_low_f32(vmax));

        vout01n0 = vmax_f32(vout01n0, vget_low_f32(vmin));
        vout01n1 = vmax_f32(vout01n1, vget_low_f32(vmin));
        vout01n2 = vmax_f32(vout01n2, vget_low_f32(vmin));
        vout01n3 = vmax_f32(vout01n3, vget_low_f32(vmin));

        vst1_f32(output + 0, vout01n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n2);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n3);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi01 = vld1_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, vi01, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(output, vout01);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 2;
    }
    output_decrement += 1 * sizeof(float);
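    // 1-element tail: the single input value is broadcast to both lanes of
    // a d-register, so only lane 0 of each result is stored.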
    if (mc & (1 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi0 = vld1_dup_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc0n0 = vfma_laneq_f32(vacc0n0, vi0, vw, 0);
            vacc0n1 = vfma_laneq_f32(vacc0n1, vi0, vw, 1);
            vacc0n2 = vfma_laneq_f32(vacc0n2, vi0, vw, 2);
            vacc0n3 = vfma_laneq_f32(vacc0n3, vi0, vw, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout0n0 = vmin_f32(vacc0n0, vget_low_f32(vmax));
        float32x2_t vout0n1 = vmin_f32(vacc0n1, vget_low_f32(vmax));
        float32x2_t vout0n2 = vmin_f32(vacc0n2, vget_low_f32(vmax));
        float32x2_t vout0n3 = vmin_f32(vacc0n3, vget_low_f32(vmax));

        vout0n0 = vmax_f32(vout0n0, vget_low_f32(vmin));
        vout0n1 = vmax_f32(vout0n1, vget_low_f32(vmin));
        vout0n2 = vmax_f32(vout0n2, vget_low_f32(vmin));
        vout0n3 = vmax_f32(vout0n3, vget_low_f32(vmin));

        vst1_lane_f32(output + 0, vout0n0, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n1, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n2, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n3, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi0 = vld1_dup_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, vi0, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          vst1_lane_f32(output, vout0, 0);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 1;
    }
  }
}
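
// A hypothetical direct invocation, for illustration only (real uses go
// through XNNPACK's operator framework, which packs the sparse weights and
// index maps into the layout this kernel expects; all names below are
// made up):
//
//   xnn_f32_spmm_minmax_ukernel_8x4__neonfma(
//       /*mc=*/m * sizeof(float), /*nc=*/n,
//       input, packed_weights, widx_dmap, nidx_nnzmap,
//       output, /*output_stride=*/m * sizeof(float), &minmax_params);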