xref: /aosp_15_r20/external/XNNPACK/src/f32-spmm/neon-blocked.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$assert NR in [1, 2, 4]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/spmm.h>


16void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__${"neonfma" if FMA else "neon"}(
17    size_t mc,
18    size_t nc,
19    const float*restrict input,
20    const float*restrict weights,
21    const int32_t*restrict widx_dmap,
22    const uint32_t*restrict nidx_nnzmap,
23    float*restrict output,
24    size_t output_stride,
25    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
26{
27  assert(mc != 0);
28  assert(mc % sizeof(float) == 0);
29  assert(nc != 0);
30
31  const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
32  const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
33  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
34  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
35    const float*restrict w = weights;
36    const int32_t* dmap = widx_dmap;
37    const uint32_t* nnzmap = nidx_nnzmap;
38    size_t n = nc;
39    while (n >= ${NR}) {
40      uint32_t nnz = *nnzmap++;
41      $for N in range(0, NR, 1):
42        float32x4_t vacc${ABC[0:4]}n${N} = vld1q_dup_f32(w); w += 1;
43        $for M in range(4, MR, 4):
44          float32x4_t vacc${ABC[M:M+4]}n${N} = vacc${ABC[0:4]}n${N};
45      if XNN_LIKELY(nnz != 0) {
46        do {
47          const intptr_t diff = *dmap++;
48          const float32x4_t vi${ABC[0:4]} = vld1q_f32(input);
49          $for M in range(4, MR, 4):
50            const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
51          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
52          $for M in range(0, MR, 16):
53            __builtin_prefetch(input + ${M+16});
54          $if NR == 1:
55            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
56          $elif NR == 2:
57            const float32x2_t vw = vld1_f32(w); w += 2;
58          $elif NR == 4:
59            const float32x4_t vw = vld1q_f32(w); w += 4;
60          __builtin_prefetch(w + 32);
61          $if NR == 1:
62            $for M in range(0, MR, 4):
63              vacc${ABC[M:M+4]}c0 = vfmaq_f32(vacc${ABC[M:M+4]}c0, vi${ABC[M:M+4]}, vw);
64          $else:
65            $for N in range(NR):
66              $for M in range(0, MR, 4):
67                vacc${ABC[M:M+4]}n${N} = vfmaq_lane${"q" if NR == 4 else ""}_f32(vacc${ABC[M:M+4]}n${N}, vi${ABC[M:M+4]}, vw, ${N});
68        } while (--nnz != 0);
69      }
70      $for N in range(0, NR, 1):
71        $for M in range(0, MR, 4):
72          float32x4_t vout${ABC[M:M+4]}n${N} = vminq_f32(vacc${ABC[M:M+4]}n${N}, vmax);
73
74      $for N in range(0, NR, 1):
75        $for M in range(0, MR, 4):
76          vout${ABC[M:M+4]}n${N} = vmaxq_f32(vout${ABC[M:M+4]}n${N}, vmin);
77
78      $for N in range(0, NR, 1):
79        $for M in range(0, MR, 4):
80          vst1q_f32(output + ${M}, vout${ABC[M:M+4]}n${N});
81        output = (float*restrict) ((uintptr_t) output + output_stride);
82      n -= ${NR};
83    }
84
85    // clean up loop, fall back to nr=1
86    if XNN_UNLIKELY(n != 0) {
87      do {
88        uint32_t nnz = *nnzmap++;
89        float32x4_t vacc${ABC[0:4]} = vld1q_dup_f32(w); w += 1;
90        $for M in range(4, MR, 4):
91          float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[0:4]};
92        if XNN_LIKELY(nnz != 0) {
93          do {
94            const intptr_t diff = *dmap++;
95            const float32x4_t vi${ABC[0:4]} = vld1q_f32(input);
96            $for M in range(4, MR, 4):
97              const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
98            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
99            $for M in range(0, MR, 16):
100              __builtin_prefetch(input + ${M+16});
101            const float32x4_t vw = vld1q_dup_f32(w); w += 1;
102            __builtin_prefetch(w + 32);
103            $for M in range(0, MR, 4):
104              vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
105          } while (--nnz != 0);
106        }
107        $for M in range(0, MR, 4):
108          float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
109
110        $for M in range(0, MR, 4):
111          vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
112
113        $for M in range(0, MR, 4):
114          vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
115        output = (float*restrict) ((uintptr_t) output + output_stride);
116        n -= 1;
117      } while (n != 0);
118    }
119    output = (float*restrict) ((uintptr_t) output - output_decrement);
120    input += ${MR};
121    mc -= ${MR} * sizeof(float);
122  }
123  if XNN_UNLIKELY(mc != 0) {
124    $for LOG2M in reversed(range((MR - 1).bit_length())):
125      $SUBMR = 1 << LOG2M
126      $if SUBMR * 2 >= MR:
127        output_decrement += ${MR - SUBMR} * sizeof(float);
128      $else:
129        output_decrement += ${SUBMR} * sizeof(float);
130      if (mc & (${SUBMR} * sizeof(float))) {
131        const float*restrict w = weights;
132        const int32_t* dmap = widx_dmap;
133        const uint32_t* nnzmap = nidx_nnzmap;
134        size_t n = nc;
135        while (n >= ${NR}) {
136          uint32_t nnz = *nnzmap++;
137          $for N in range(0, NR, 1):
138            $if SUBMR < 4:
139              float32x2_t vacc${ABC[0:SUBMR]}n${N} = vld1_dup_f32(w); w += 1;
140            $else:
141              float32x4_t vacc${ABC[0:4]}n${N} = vld1q_dup_f32(w); w += 1;
142            $for M in range(4, SUBMR, 4):
143              float32x4_t vacc${ABC[M:M+4]}n${N} = vacc${ABC[0:4]}n${N};
144          if XNN_LIKELY(nnz != 0) {
145            do {
146              const intptr_t diff = *dmap++;
147              $if SUBMR == 1:
148                const float32x2_t vi${ABC[0]} = vld1_dup_f32(input);
149              $elif SUBMR == 2:
150                const float32x2_t vi${ABC[0:2]} = vld1_f32(input);
151              $else:
152                const float32x4_t vi${ABC[0:4]} = vld1q_f32(input);
153              $for M in range(4, SUBMR, 4):
154                const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
155              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
156              $if NR == 1:
157                $if SUBMR < 4:
158                  const float32x2_t vw = vld1_dup_f32(w); w += 1;
159                $else:
160                  const float32x4_t vw = vld1q_dup_f32(w); w += 1;
161              $elif NR == 2:
162                const float32x2_t vw = vld1_f32(w); w += 2;
163              $elif NR == 4:
164                const float32x4_t vw = vld1q_f32(w); w += 4;
165
166              $if NR == 1:
167                $if SUBMR < 4:
168                    vacc${ABC[0:SUBMR]}c0 = vfmaq_f32(vacc${ABC[0:SUBMR]}c0, vi${ABC[0:SUBMR]}, vw);
169                $else:
170                  $for M in range(0, SUBMR, 4):
171                    vacc${ABC[M:M+4]}c0 = vfmaq_f32(vacc${ABC[M:M+4]}c0, vi${ABC[M:M+4]}, vw);
172              $else:
173                $for N in range(NR):
174                  $if SUBMR < 4:
175                    vacc${ABC[0:SUBMR]}n${N} = vfma_lane${"q" if NR == 4 else ""}_f32(vacc${ABC[0:SUBMR]}n${N}, vi${ABC[0:SUBMR]}, vw, ${N});
176                  $else:
177                    $for M in range(0, SUBMR, 4):
178                      vacc${ABC[M:M+4]}n${N} = vfmaq_lane${"q" if NR == 4 else ""}_f32(vacc${ABC[M:M+4]}n${N}, vi${ABC[M:M+4]}, vw, ${N});
179            } while (--nnz != 0);
180          }
181          $for N in range(0, NR, 1):
182            $if SUBMR < 4:
183              float32x2_t vout${ABC[0:SUBMR]}n${N} = vmin_f32(vacc${ABC[0:SUBMR]}n${N}, vget_low_f32(vmax));
184            $else:
185              $for M in range(0, SUBMR, 4):
186                float32x4_t vout${ABC[M:M+4]}n${N} = vminq_f32(vacc${ABC[M:M+4]}n${N}, vmax);
187
188          $for N in range(0, NR, 1):
189            $if SUBMR < 4:
190              vout${ABC[0:SUBMR]}n${N} = vmax_f32(vout${ABC[0:SUBMR]}n${N}, vget_low_f32(vmin));
191            $else:
192              $for M in range(0, SUBMR, 4):
193                vout${ABC[M:M+4]}n${N} = vmaxq_f32(vout${ABC[M:M+4]}n${N}, vmin);
194
195          $for N in range(NR):
196            $if SUBMR == 1:
197              vst1_lane_f32(output + ${M}, vout${ABC[0:SUBMR]}n${N}, 0);
198            $elif SUBMR == 2:
199              vst1_f32(output + ${M}, vout${ABC[0:SUBMR]}n${N});
200            $else:
201              $for M in range(0, SUBMR, 4):
202                vst1q_f32(output + ${M}, vout${ABC[M:M+4]}n${N});
203            output = (float*restrict) ((uintptr_t) output + output_stride);
204          n -= ${NR};
205        }
206
207        // clean up loop, fall back to nr=1
208        if XNN_UNLIKELY(n != 0) {
209          do {
210            uint32_t nnz = *nnzmap++;
211            $if SUBMR < 4:
212              float32x2_t vacc${ABC[0:SUBMR]} = vld1_dup_f32(w); w += 1;
213            $else:
214              float32x4_t vacc${ABC[0:4]} = vld1q_dup_f32(w); w += 1;
215            $for M in range(4, SUBMR, 4):
216              float32x4_t vacc${ABC[M:M+4]} = vacc${ABC[0:4]};
217            if XNN_LIKELY(nnz != 0) {
218              do {
219                const intptr_t diff = *dmap++;
220                $if SUBMR == 1:
221                  const float32x2_t vi${ABC[0:1]} = vld1_dup_f32(input);
222                $elif SUBMR == 2:
223                  const float32x2_t vi${ABC[0:2]} = vld1_f32(input);
224                $else:
225                  const float32x4_t vi${ABC[0:4]} = vld1q_f32(input);
226                $for M in range(4, SUBMR, 4):
227                  const float32x4_t vi${ABC[M:M+4]} = vld1q_f32(input + ${M});
228                input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
229                $if SUBMR < 4:
230                  const float32x2_t vw = vld1_dup_f32(w); w += 1;
231                  vacc${ABC[0:SUBMR]} = vfma_f32(vacc${ABC[0:SUBMR]}, vi${ABC[0:SUBMR]}, vw);
232                $else:
233                  const float32x4_t vw = vld1q_dup_f32(w); w += 1;
234                  $for M in range(0, SUBMR, 4):
235                    vacc${ABC[M:M+4]} = vfmaq_f32(vacc${ABC[M:M+4]}, vi${ABC[M:M+4]}, vw);
236              } while (--nnz != 0);
237            }
238            $if SUBMR < 4:
239              float32x2_t vout${ABC[0:SUBMR]} = vmin_f32(vacc${ABC[0:SUBMR]}, vget_low_f32(vmax));
240              vout${ABC[0:SUBMR]} = vmax_f32(vout${ABC[0:SUBMR]}, vget_low_f32(vmin));
241            $else:
242              $for M in range(0, SUBMR, 4):
243                float32x4_t vout${ABC[M:M+4]} = vminq_f32(vacc${ABC[M:M+4]}, vmax);
244
245              $for M in range(0, SUBMR, 4):
246                vout${ABC[M:M+4]} = vmaxq_f32(vout${ABC[M:M+4]}, vmin);
247
248            $if SUBMR == 1:
249              vst1_lane_f32(output, vout${ABC[0:1]}, 1);
250            $elif SUBMR == 2:
251              vst1_f32(output, vout${ABC[0:2]});
252            $else:
253              $for M in range(0, SUBMR, 4):
254                vst1q_f32(output + ${M}, vout${ABC[M:M+4]});
255            output = (float*restrict) ((uintptr_t) output + output_stride);
256            n -= 1;
257          } while (n != 0);
258        }
259        output = (float*restrict) ((uintptr_t) output - output_decrement);
260        input += ${SUBMR};
261      }
262    }
263}
264