// Auto-generated file. Do not edit!
// Template: src/f32-spmm/neon.c.in
// Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

// SpMM (sparse-weights times dense-input matrix multiplication) micro-kernel
// with min/max output clamping, for NEON.
//
// The kernel processes the dense input in tiles of 32 floats along the M
// dimension (with remainder tiles of 16/8/4/2/1), and for the 32-wide main
// tile unrolls the loop over non-zero weights by 2 (hence "x2" in the name).
//
//   mc            - size of the M dimension to process, in BYTES
//                   (must be a non-zero multiple of sizeof(float))
//   nc            - number of output channels (rows of the sparse weight
//                   matrix); must be non-zero
//   input         - dense input; advanced by the byte diffs in widx_dmap
//   weights       - packed weights: for each output channel, a bias followed
//                   by that channel's non-zero weight values
//   widx_dmap     - byte offsets ("diffs") added to the input pointer after
//                   each non-zero weight is consumed
//   nidx_nnzmap   - number of non-zero weights per output channel
//   output        - output buffer, written nc rows of the current tile
//   output_stride - byte stride between consecutive output channels
//   params        - scalar min/max clamping parameters
void xnn_f32_spmm_minmax_ukernel_32x1__neon_x2(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  // After writing all nc channels of one 32-float tile, rewind output back to
  // the start of the next tile: nc strides forward minus one tile width.
  size_t output_decrement = output_stride * nc - 32 * sizeof(float);
  while XNN_LIKELY(mc >= 32 * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
      // x0 accumulators start from the channel bias; x1 accumulators start at
      // zero and are summed into x0 after the unrolled loop.
      float32x4_t vacc0123x0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc0123x1 = vmovq_n_f32(0.0f);
      float32x4_t vacc4567x0 = vacc0123x0;
      float32x4_t vacc4567x1 = vmovq_n_f32(0.0f);
      float32x4_t vacc89ABx0 = vacc0123x0;
      float32x4_t vacc89ABx1 = vmovq_n_f32(0.0f);
      float32x4_t vaccCDEFx0 = vacc0123x0;
      float32x4_t vaccCDEFx1 = vmovq_n_f32(0.0f);
      float32x4_t vaccGHIJx0 = vacc0123x0;
      float32x4_t vaccGHIJx1 = vmovq_n_f32(0.0f);
      float32x4_t vaccKLMNx0 = vacc0123x0;
      float32x4_t vaccKLMNx1 = vmovq_n_f32(0.0f);
      float32x4_t vaccOPQRx0 = vacc0123x0;
      float32x4_t vaccOPQRx1 = vmovq_n_f32(0.0f);
      float32x4_t vaccSTUVx0 = vacc0123x0;
      float32x4_t vaccSTUVx1 = vmovq_n_f32(0.0f);
      // Main loop over non-zero weights, unrolled by 2.
      for (; nnz >= 2; nnz -= 2) {
        const intptr_t diff0 = dmap[0];
        const intptr_t diff1 = dmap[1];
        dmap += 2;
        const float32x4_t vi0123x0 = vld1q_f32(input);
        const float32x4_t vi4567x0 = vld1q_f32(input + 4);
        const float32x4_t vi89ABx0 = vld1q_f32(input + 8);
        const float32x4_t viCDEFx0 = vld1q_f32(input + 12);
        const float32x4_t viGHIJx0 = vld1q_f32(input + 16);
        const float32x4_t viKLMNx0 = vld1q_f32(input + 20);
        const float32x4_t viOPQRx0 = vld1q_f32(input + 24);
        const float32x4_t viSTUVx0 = vld1q_f32(input + 28);
        input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff0);
        __builtin_prefetch(input + 16);
        __builtin_prefetch(input + 32);
        const float32x4_t vw0 = vld1q_dup_f32(w); w += 1;
        __builtin_prefetch(w + 32);
        vacc0123x0 = vmlaq_f32(vacc0123x0, vi0123x0, vw0);
        vacc4567x0 = vmlaq_f32(vacc4567x0, vi4567x0, vw0);
        vacc89ABx0 = vmlaq_f32(vacc89ABx0, vi89ABx0, vw0);
        vaccCDEFx0 = vmlaq_f32(vaccCDEFx0, viCDEFx0, vw0);
        vaccGHIJx0 = vmlaq_f32(vaccGHIJx0, viGHIJx0, vw0);
        vaccKLMNx0 = vmlaq_f32(vaccKLMNx0, viKLMNx0, vw0);
        vaccOPQRx0 = vmlaq_f32(vaccOPQRx0, viOPQRx0, vw0);
        vaccSTUVx0 = vmlaq_f32(vaccSTUVx0, viSTUVx0, vw0);
        const float32x4_t vi0123x1 = vld1q_f32(input);
        const float32x4_t vi4567x1 = vld1q_f32(input + 4);
        const float32x4_t vi89ABx1 = vld1q_f32(input + 8);
        const float32x4_t viCDEFx1 = vld1q_f32(input + 12);
        const float32x4_t viGHIJx1 = vld1q_f32(input + 16);
        const float32x4_t viKLMNx1 = vld1q_f32(input + 20);
        const float32x4_t viOPQRx1 = vld1q_f32(input + 24);
        const float32x4_t viSTUVx1 = vld1q_f32(input + 28);
        input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff1);
        __builtin_prefetch(input + 16);
        __builtin_prefetch(input + 32);
        const float32x4_t vw1 = vld1q_dup_f32(w); w += 1;
        __builtin_prefetch(w + 32);
        vacc0123x1 = vmlaq_f32(vacc0123x1, vi0123x1, vw1);
        vacc4567x1 = vmlaq_f32(vacc4567x1, vi4567x1, vw1);
        vacc89ABx1 = vmlaq_f32(vacc89ABx1, vi89ABx1, vw1);
        vaccCDEFx1 = vmlaq_f32(vaccCDEFx1, viCDEFx1, vw1);
        vaccGHIJx1 = vmlaq_f32(vaccGHIJx1, viGHIJx1, vw1);
        vaccKLMNx1 = vmlaq_f32(vaccKLMNx1, viKLMNx1, vw1);
        vaccOPQRx1 = vmlaq_f32(vaccOPQRx1, viOPQRx1, vw1);
        vaccSTUVx1 = vmlaq_f32(vaccSTUVx1, viSTUVx1, vw1);
      }
      // Merge the two accumulator sets of the unrolled loop.
      float32x4_t vacc0123 = vacc0123x0;
      float32x4_t vacc4567 = vacc4567x0;
      float32x4_t vacc89AB = vacc89ABx0;
      float32x4_t vaccCDEF = vaccCDEFx0;
      float32x4_t vaccGHIJ = vaccGHIJx0;
      float32x4_t vaccKLMN = vaccKLMNx0;
      float32x4_t vaccOPQR = vaccOPQRx0;
      float32x4_t vaccSTUV = vaccSTUVx0;
      vacc0123 = vaddq_f32(vacc0123, vacc0123x1);
      vacc4567 = vaddq_f32(vacc4567, vacc4567x1);
      vacc89AB = vaddq_f32(vacc89AB, vacc89ABx1);
      vaccCDEF = vaddq_f32(vaccCDEF, vaccCDEFx1);
      vaccGHIJ = vaddq_f32(vaccGHIJ, vaccGHIJx1);
      vaccKLMN = vaddq_f32(vaccKLMN, vaccKLMNx1);
      vaccOPQR = vaddq_f32(vaccOPQR, vaccOPQRx1);
      vaccSTUV = vaddq_f32(vaccSTUV, vaccSTUVx1);
      // Remainder loop for the last (odd) non-zero weight, if any.
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t vi0123 = vld1q_f32(input);
          const float32x4_t vi4567 = vld1q_f32(input + 4);
          const float32x4_t vi89AB = vld1q_f32(input + 8);
          const float32x4_t viCDEF = vld1q_f32(input + 12);
          const float32x4_t viGHIJ = vld1q_f32(input + 16);
          const float32x4_t viKLMN = vld1q_f32(input + 20);
          const float32x4_t viOPQR = vld1q_f32(input + 24);
          const float32x4_t viSTUV = vld1q_f32(input + 28);
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          __builtin_prefetch(input + 16);
          __builtin_prefetch(input + 32);
          const float32x4_t vw = vld1q_dup_f32(w); w += 1;
          __builtin_prefetch(w + 32);
          vacc0123 = vmlaq_f32(vacc0123, vi0123, vw);
          vacc4567 = vmlaq_f32(vacc4567, vi4567, vw);
          vacc89AB = vmlaq_f32(vacc89AB, vi89AB, vw);
          vaccCDEF = vmlaq_f32(vaccCDEF, viCDEF, vw);
          vaccGHIJ = vmlaq_f32(vaccGHIJ, viGHIJ, vw);
          vaccKLMN = vmlaq_f32(vaccKLMN, viKLMN, vw);
          vaccOPQR = vmlaq_f32(vaccOPQR, viOPQR, vw);
          vaccSTUV = vmlaq_f32(vaccSTUV, viSTUV, vw);
        } while (--nnz != 0);
      }
      // Clamp to [min, max] and store one output channel of the tile.
      float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
      float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
      float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
      float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
      float32x4_t voutGHIJ = vminq_f32(vaccGHIJ, vmax);
      float32x4_t voutKLMN = vminq_f32(vaccKLMN, vmax);
      float32x4_t voutOPQR = vminq_f32(vaccOPQR, vmax);
      float32x4_t voutSTUV = vminq_f32(vaccSTUV, vmax);
      vout0123 = vmaxq_f32(vout0123, vmin);
      vout4567 = vmaxq_f32(vout4567, vmin);
      vout89AB = vmaxq_f32(vout89AB, vmin);
      voutCDEF = vmaxq_f32(voutCDEF, vmin);
      voutGHIJ = vmaxq_f32(voutGHIJ, vmin);
      voutKLMN = vmaxq_f32(voutKLMN, vmin);
      voutOPQR = vmaxq_f32(voutOPQR, vmin);
      voutSTUV = vmaxq_f32(voutSTUV, vmin);
      vst1q_f32(output, vout0123);
      vst1q_f32(output + 4, vout4567);
      vst1q_f32(output + 8, vout89AB);
      vst1q_f32(output + 12, voutCDEF);
      vst1q_f32(output + 16, voutGHIJ);
      vst1q_f32(output + 20, voutKLMN);
      vst1q_f32(output + 24, voutOPQR);
      vst1q_f32(output + 28, voutSTUV);
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += 32;
    mc -= 32 * sizeof(float);
  }
  // Remainder tiles: 16, 8, 4, 2, then 1 float of the M dimension.
  if XNN_UNLIKELY(mc != 0) {
    output_decrement += 16 * sizeof(float);
    if (mc & (16 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        float32x4_t vacc89AB = vacc0123;
        float32x4_t vaccCDEF = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            const float32x4_t vi89AB = vld1q_f32(input + 8);
            const float32x4_t viCDEF = vld1q_f32(input + 12);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
            vacc0123 = vmlaq_f32(vacc0123, vi0123, vw);
            vacc4567 = vmlaq_f32(vacc4567, vi4567, vw);
            vacc89AB = vmlaq_f32(vacc89AB, vi89AB, vw);
            vaccCDEF = vmlaq_f32(vaccCDEF, viCDEF, vw);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);
        float32x4_t voutCDEF = vminq_f32(vaccCDEF, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vout89AB = vmaxq_f32(vout89AB, vmin);
        voutCDEF = vmaxq_f32(voutCDEF, vmin);
        vst1q_f32(output, vout0123);
        vst1q_f32(output + 4, vout4567);
        vst1q_f32(output + 8, vout89AB);
        vst1q_f32(output + 12, voutCDEF);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 16;
    }
    output_decrement += 8 * sizeof(float);
    if (mc & (8 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
            vacc0123 = vmlaq_f32(vacc0123, vi0123, vw);
            vacc4567 = vmlaq_f32(vacc4567, vi4567, vw);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vst1q_f32(output, vout0123);
        vst1q_f32(output + 4, vout4567);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 8;
    }
    output_decrement += 4 * sizeof(float);
    if (mc & (4 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
            vacc0123 = vmlaq_f32(vacc0123, vi0123, vw);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        vout0123 = vmaxq_f32(vout0123, vmin);
        vst1q_f32(output, vout0123);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 4;
    }
    output_decrement += 2 * sizeof(float);
    if (mc & (2 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi01 = vld1_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x2_t vw = vld1_dup_f32(w); w += 1;
            vacc01 = vmla_f32(vacc01, vi01, vw);
          } while (--nnz != 0);
        }
        float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
        vout01 = vmax_f32(vout01, vget_low_f32(vmin));
        vst1_f32(output, vout01);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 2;
    }
    output_decrement += 1 * sizeof(float);
    if (mc & (1 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi0 = vld1_dup_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x2_t vw = vld1_dup_f32(w); w += 1;
            vacc0 = vmla_f32(vacc0, vi0, vw);
          } while (--nnz != 0);
        }
        float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
        vout0 = vmax_f32(vout0, vget_low_f32(vmin));
        vst1_lane_f32(output, vout0, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
      } while (--n != 0);
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 1;
    }
  }
}