// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

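// SpMM (sparse matrix times dense matrix) microkernel with a 12x2 tile:
// the main loop processes 12 rows of the dense input (MR = 12) per pass,
// handling output channels two at a time (NR = 2) with a fallback for an
// odd trailing channel, using NEON FMA intrinsics.
//
// As read by this kernel, weights packs, for each channel block, one initial
// (bias) accumulator value per channel followed by its non-zero weights;
// nidx_nnzmap holds the non-zero count for each channel block and widx_dmap
// holds the byte offsets used to advance the input pointer between non-zeros.
// mc is the size of the dense dimension in bytes (a multiple of sizeof(float)),
// nc is the number of output channels, and results are clamped to
// [params->scalar.min, params->scalar.max].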
void xnn_f32_spmm_minmax_ukernel_12x2__neonfma(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t output_decrement = output_stride * nc - 12 * sizeof(float);
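  // Main loop: process the dense input 12 rows (MR = 12) at a time.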
  while XNN_LIKELY(mc >= 12 * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
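    // Process output channels two at a time (NR = 2).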
    while (n >= 2) {
      uint32_t nnz = *nnzmap++;
      float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n0 = vacc0123n0;
      float32x4_t vacc89ABn0 = vacc0123n0;
      float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc4567n1 = vacc0123n1;
      float32x4_t vacc89ABn1 = vacc0123n1;
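      // Accumulate over this channel pair's non-zero weights: each iteration
      // loads 12 input values, multiply-accumulates them into both channels'
      // accumulators, and advances the input pointer by the byte offset taken
      // from widx_dmap.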
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t vi0123 = vld1q_f32(input);
          const float32x4_t vi4567 = vld1q_f32(input + 4);
          const float32x4_t vi89AB = vld1q_f32(input + 8);
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
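          // Prefetch upcoming input and weight data.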
          __builtin_prefetch(input + 16);
          const float32x2_t vw = vld1_f32(w); w += 2;
          __builtin_prefetch(w + 32);
          vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
          vacc4567n0 = vfmaq_lane_f32(vacc4567n0, vi4567, vw, 0);
          vacc89ABn0 = vfmaq_lane_f32(vacc89ABn0, vi89AB, vw, 0);
          vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
          vacc4567n1 = vfmaq_lane_f32(vacc4567n1, vi4567, vw, 1);
          vacc89ABn1 = vfmaq_lane_f32(vacc89ABn1, vi89AB, vw, 1);
        } while (--nnz != 0);
      }
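      // Clamp the accumulators to the [min, max] output range.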
      float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
      float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
      float32x4_t vout89ABn0 = vminq_f32(vacc89ABn0, vmax);
      float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
      float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);
      float32x4_t vout89ABn1 = vminq_f32(vacc89ABn1, vmax);

      vout0123n0 = vmaxq_f32(vout0123n0, vmin);
      vout4567n0 = vmaxq_f32(vout4567n0, vmin);
      vout89ABn0 = vmaxq_f32(vout89ABn0, vmin);
      vout0123n1 = vmaxq_f32(vout0123n1, vmin);
      vout4567n1 = vmaxq_f32(vout4567n1, vmin);
      vout89ABn1 = vmaxq_f32(vout89ABn1, vmin);

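      // Store 12 outputs for each of the two channels; consecutive channels
      // are output_stride bytes apart.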
      vst1q_f32(output + 0, vout0123n0);
      vst1q_f32(output + 4, vout4567n0);
      vst1q_f32(output + 8, vout89ABn0);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n1);
      vst1q_f32(output + 4, vout4567n1);
      vst1q_f32(output + 8, vout89ABn1);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      n -= 2;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(n != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567 = vacc0123;
        float32x4_t vacc89AB = vacc0123;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            const float32x4_t vi89AB = vld1q_f32(input + 8);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
            vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
            vacc89AB = vfmaq_f32(vacc89AB, vi89AB, vw);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
        float32x4_t vout4567 = vminq_f32(vacc4567, vmax);
        float32x4_t vout89AB = vminq_f32(vacc89AB, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);
        vout4567 = vmaxq_f32(vout4567, vmin);
        vout89AB = vmaxq_f32(vout89AB, vmin);

        vst1q_f32(output + 0, vout0123);
        vst1q_f32(output + 4, vout4567);
        vst1q_f32(output + 8, vout89AB);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 1;
      } while (n != 0);
    }
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += 12;
    mc -= 12 * sizeof(float);
  }
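  // Tail: handle the remaining 1-11 rows with progressively narrower blocks
  // of 8, 4, 2, and 1 rows.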
  if XNN_UNLIKELY(mc != 0) {
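    // Each tail block writes fewer rows than the main loop, so the per-column
    // output rewind (output_decrement) grows accordingly.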
    output_decrement += 4 * sizeof(float);
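    // 8-row tail: two 128-bit vectors per output channel.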
    if (mc & (8 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n0 = vacc0123n0;
        float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc4567n1 = vacc0123n1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            const float32x4_t vi4567 = vld1q_f32(input + 4);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x2_t vw = vld1_f32(w); w += 2;

            vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
            vacc4567n0 = vfmaq_lane_f32(vacc4567n0, vi4567, vw, 0);
            vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
            vacc4567n1 = vfmaq_lane_f32(vacc4567n1, vi4567, vw, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
        float32x4_t vout4567n0 = vminq_f32(vacc4567n0, vmax);
        float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
        float32x4_t vout4567n1 = vminq_f32(vacc4567n1, vmax);

        vout0123n0 = vmaxq_f32(vout0123n0, vmin);
        vout4567n0 = vmaxq_f32(vout4567n0, vmin);
        vout0123n1 = vmaxq_f32(vout0123n1, vmin);
        vout4567n1 = vmaxq_f32(vout4567n1, vmin);

        vst1q_f32(output + 0, vout0123n0);
        vst1q_f32(output + 4, vout4567n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n1);
        vst1q_f32(output + 4, vout4567n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          float32x4_t vacc4567 = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t vi0123 = vld1q_f32(input);
              const float32x4_t vi4567 = vld1q_f32(input + 4);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
              vacc4567 = vfmaq_f32(vacc4567, vi4567, vw);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);
          float32x4_t vout4567 = vminq_f32(vacc4567, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);
          vout4567 = vmaxq_f32(vout4567, vmin);

          vst1q_f32(output + 0, vout0123);
          vst1q_f32(output + 4, vout4567);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 8;
    }
    output_decrement += 4 * sizeof(float);
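    // 4-row tail: one 128-bit vector per output channel.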
    if (mc & (4 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
        float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x2_t vw = vld1_f32(w); w += 2;

            vacc0123n0 = vfmaq_lane_f32(vacc0123n0, vi0123, vw, 0);
            vacc0123n1 = vfmaq_lane_f32(vacc0123n1, vi0123, vw, 1);
          } while (--nnz != 0);
        }
        float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
        float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);

        vout0123n0 = vmaxq_f32(vout0123n0, vmin);
        vout0123n1 = vmaxq_f32(vout0123n1, vmin);

        vst1q_f32(output + 0, vout0123n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1q_f32(output + 0, vout0123n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x4_t vi0123 = vld1q_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x4_t vw = vld1q_dup_f32(w); w += 1;
              vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
            } while (--nnz != 0);
          }
          float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

          vout0123 = vmaxq_f32(vout0123, vmin);

          vst1q_f32(output + 0, vout0123);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 4;
    }
    output_decrement += 2 * sizeof(float);
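    // 2-row tail: uses 64-bit NEON vectors (float32x2_t).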
    if (mc & (2 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi01 = vld1_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x2_t vw = vld1_f32(w); w += 2;

            vacc01n0 = vfma_lane_f32(vacc01n0, vi01, vw, 0);
            vacc01n1 = vfma_lane_f32(vacc01n1, vi01, vw, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout01n0 = vmin_f32(vacc01n0, vget_low_f32(vmax));
        float32x2_t vout01n1 = vmin_f32(vacc01n1, vget_low_f32(vmax));

        vout01n0 = vmax_f32(vout01n0, vget_low_f32(vmin));
        vout01n1 = vmax_f32(vout01n1, vget_low_f32(vmin));

        vst1_f32(output + 0, vout01n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi01 = vld1_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, vi01, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(output, vout01);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 2;
    }
    output_decrement += 1 * sizeof(float);
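    // 1-row tail: the single value is kept duplicated in both lanes of a
    // float32x2_t.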
    if (mc & (1 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 2) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n1 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi0 = vld1_dup_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x2_t vw = vld1_f32(w); w += 2;

            vacc0n0 = vfma_lane_f32(vacc0n0, vi0, vw, 0);
            vacc0n1 = vfma_lane_f32(vacc0n1, vi0, vw, 1);
          } while (--nnz != 0);
        }
        float32x2_t vout0n0 = vmin_f32(vacc0n0, vget_low_f32(vmax));
        float32x2_t vout0n1 = vmin_f32(vacc0n1, vget_low_f32(vmax));

        vout0n0 = vmax_f32(vout0n0, vget_low_f32(vmin));
        vout0n1 = vmax_f32(vout0n1, vget_low_f32(vmin));

        vst1_lane_f32(output + 0, vout0n0, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n1, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 2;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi0 = vld1_dup_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, vi0, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          vst1_lane_f32(output, vout0, 0);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 1;
    }
  }
}