// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/neon-blocked.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>

// F32 sparse-matrix(weights) x dense-matrix(input) micro-kernel with min/max
// clamping, specialized for NEON+FMA.
//
// Tile shape: MR=4 input channels per iteration, blocked over NR=4 output
// channels at a time (hence "4x4").
//
//   mc           - number of input channels to process, in BYTES (must be a
//                  multiple of sizeof(float)).
//   nc           - number of output channels (rows of the sparse weight matrix).
//   input        - dense input; advanced between nonzeros via widx_dmap byte
//                  deltas.
//   weights      - packed sparse weights: per output-channel block, one bias
//                  per channel followed by the nonzero weight values.
//   widx_dmap    - per-nonzero signed byte offsets applied to the input pointer.
//   nidx_nnzmap  - per output-channel block, the number of nonzeros.
//   output       - dense output, rows separated by output_stride bytes.
//   params       - scalar min/max clamping bounds.
void xnn_f32_spmm_minmax_ukernel_4x4__neonfma(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
  size_t output_decrement = output_stride * nc - 4 * sizeof(float);
  // Main loop: process 4 input channels (MR=4) per iteration.
  while XNN_LIKELY(mc >= 4 * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    // Blocked loop over output channels, 4 at a time (NR=4).
    while (n >= 4) {
      uint32_t nnz = *nnzmap++;
      // Initialize each accumulator with its per-channel bias, broadcast
      // across the 4 input channels.
      float32x4_t vacc0123n0 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc0123n1 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc0123n2 = vld1q_dup_f32(w); w += 1;
      float32x4_t vacc0123n3 = vld1q_dup_f32(w); w += 1;
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float32x4_t vi0123 = vld1q_f32(input);
          // diff is a signed byte delta to the next nonzero's input position.
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          __builtin_prefetch(input + 16);
          // 4 weights, one per output channel in this block.
          const float32x4_t vw = vld1q_f32(w); w += 4;
          __builtin_prefetch(w + 32);
          vacc0123n0 = vfmaq_laneq_f32(vacc0123n0, vi0123, vw, 0);
          vacc0123n1 = vfmaq_laneq_f32(vacc0123n1, vi0123, vw, 1);
          vacc0123n2 = vfmaq_laneq_f32(vacc0123n2, vi0123, vw, 2);
          vacc0123n3 = vfmaq_laneq_f32(vacc0123n3, vi0123, vw, 3);
        } while (--nnz != 0);
      }
      // Clamp to [min, max].
      float32x4_t vout0123n0 = vminq_f32(vacc0123n0, vmax);
      float32x4_t vout0123n1 = vminq_f32(vacc0123n1, vmax);
      float32x4_t vout0123n2 = vminq_f32(vacc0123n2, vmax);
      float32x4_t vout0123n3 = vminq_f32(vacc0123n3, vmax);

      vout0123n0 = vmaxq_f32(vout0123n0, vmin);
      vout0123n1 = vmaxq_f32(vout0123n1, vmin);
      vout0123n2 = vmaxq_f32(vout0123n2, vmin);
      vout0123n3 = vmaxq_f32(vout0123n3, vmin);

      vst1q_f32(output + 0, vout0123n0);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n1);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n2);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      vst1q_f32(output + 0, vout0123n3);
      output = (float*restrict) ((uintptr_t) output + output_stride);
      n -= 4;
    }

    // clean up loop, fall back to nr=1
    if XNN_UNLIKELY(n != 0) {
      do {
        uint32_t nnz = *nnzmap++;
        float32x4_t vacc0123 = vld1q_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x4_t vi0123 = vld1q_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            __builtin_prefetch(input + 16);
            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
            __builtin_prefetch(w + 32);
            vacc0123 = vfmaq_f32(vacc0123, vi0123, vw);
          } while (--nnz != 0);
        }
        float32x4_t vout0123 = vminq_f32(vacc0123, vmax);

        vout0123 = vmaxq_f32(vout0123, vmin);

        vst1q_f32(output + 0, vout0123);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 1;
      } while (n != 0);
    }
    // Rewind output to the start of the next MR-tile's first row.
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += 4;
    mc -= 4 * sizeof(float);
  }
  // Remainder: handle mc tails of 2 and then 1 input channel(s).
  if XNN_UNLIKELY(mc != 0) {
    output_decrement += 2 * sizeof(float);
    if (mc & (2 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc01n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc01n3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi01 = vld1_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc01n0 = vfma_laneq_f32(vacc01n0, vi01, vw, 0);
            vacc01n1 = vfma_laneq_f32(vacc01n1, vi01, vw, 1);
            vacc01n2 = vfma_laneq_f32(vacc01n2, vi01, vw, 2);
            vacc01n3 = vfma_laneq_f32(vacc01n3, vi01, vw, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout01n0 = vmin_f32(vacc01n0, vget_low_f32(vmax));
        float32x2_t vout01n1 = vmin_f32(vacc01n1, vget_low_f32(vmax));
        float32x2_t vout01n2 = vmin_f32(vacc01n2, vget_low_f32(vmax));
        float32x2_t vout01n3 = vmin_f32(vacc01n3, vget_low_f32(vmax));

        vout01n0 = vmax_f32(vout01n0, vget_low_f32(vmin));
        vout01n1 = vmax_f32(vout01n1, vget_low_f32(vmin));
        vout01n2 = vmax_f32(vout01n2, vget_low_f32(vmin));
        vout01n3 = vmax_f32(vout01n3, vget_low_f32(vmin));

        vst1_f32(output + 0, vout01n0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n1);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n2);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_f32(output + 0, vout01n3);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc01 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi01 = vld1_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc01 = vfma_f32(vacc01, vi01, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout01 = vmin_f32(vacc01, vget_low_f32(vmax));
          vout01 = vmax_f32(vout01, vget_low_f32(vmin));

          vst1_f32(output, vout01);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 2;
    }
    output_decrement += 1 * sizeof(float);
    if (mc & (1 * sizeof(float))) {
      const float*restrict w = weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      while (n >= 4) {
        uint32_t nnz = *nnzmap++;
        float32x2_t vacc0n0 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n1 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n2 = vld1_dup_f32(w); w += 1;
        float32x2_t vacc0n3 = vld1_dup_f32(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float32x2_t vi0 = vld1_dup_f32(input);
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
            const float32x4_t vw = vld1q_f32(w); w += 4;

            vacc0n0 = vfma_laneq_f32(vacc0n0, vi0, vw, 0);
            vacc0n1 = vfma_laneq_f32(vacc0n1, vi0, vw, 1);
            vacc0n2 = vfma_laneq_f32(vacc0n2, vi0, vw, 2);
            vacc0n3 = vfma_laneq_f32(vacc0n3, vi0, vw, 3);
          } while (--nnz != 0);
        }
        float32x2_t vout0n0 = vmin_f32(vacc0n0, vget_low_f32(vmax));
        float32x2_t vout0n1 = vmin_f32(vacc0n1, vget_low_f32(vmax));
        float32x2_t vout0n2 = vmin_f32(vacc0n2, vget_low_f32(vmax));
        float32x2_t vout0n3 = vmin_f32(vacc0n3, vget_low_f32(vmax));

        vout0n0 = vmax_f32(vout0n0, vget_low_f32(vmin));
        vout0n1 = vmax_f32(vout0n1, vget_low_f32(vmin));
        vout0n2 = vmax_f32(vout0n2, vget_low_f32(vmin));
        vout0n3 = vmax_f32(vout0n3, vget_low_f32(vmin));

        vst1_lane_f32(output + 0, vout0n0, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n1, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n2, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        vst1_lane_f32(output + 0, vout0n3, 0);
        output = (float*restrict) ((uintptr_t) output + output_stride);
        n -= 4;
      }

      // clean up loop, fall back to nr=1
      if XNN_UNLIKELY(n != 0) {
        do {
          uint32_t nnz = *nnzmap++;
          float32x2_t vacc0 = vld1_dup_f32(w); w += 1;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              const float32x2_t vi0 = vld1_dup_f32(input);
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const float32x2_t vw = vld1_dup_f32(w); w += 1;
              vacc0 = vfma_f32(vacc0, vi0, vw);
            } while (--nnz != 0);
          }
          float32x2_t vout0 = vmin_f32(vacc0, vget_low_f32(vmax));
          vout0 = vmax_f32(vout0, vget_low_f32(vmin));

          // Store lane 0, consistent with the other single-channel stores
          // above (lane 1 happened to hold the same value only because all
          // contributing loads were lane-duplicated).
          vst1_lane_f32(output, vout0, 0);
          output = (float*restrict) ((uintptr_t) output + output_stride);
          n -= 1;
        } while (n != 0);
      }
      output = (float*restrict) ((uintptr_t) output - output_decrement);
      input += 1;
    }
  }
}
255