xref: /aosp_15_r20/external/XNNPACK/src/f32-spmm/wasmsimd-pipelined.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert MR % 4 == 0
7$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
8$assert MINMAX in ["MINMAX", "PMINMAX"]
9#include <assert.h>
10
11#include <wasm_simd128.h>
12
13#include <xnnpack/spmm.h>
14
15
// Template-time selection: MINMAX picks wasm_f32x4_min / wasm_f32x4_max and the
// _arm name suffix, PMINMAX picks wasm_f32x4_pmin / wasm_f32x4_pmax and the
// _x86 name suffix.
16$WASM_F32X4_MIN={"MINMAX": "wasm_f32x4_min", "PMINMAX": "wasm_f32x4_pmin"}[MINMAX]
17$WASM_F32X4_MAX={"MINMAX": "wasm_f32x4_max", "PMINMAX": "wasm_f32x4_pmax"}[MINMAX]
18$ARCH_SUFFIX = "_x86" if MINMAX == "PMINMAX" else "_arm"
// F32 SpMM micro-kernel (per its name) with min/max clamping, written with
// WebAssembly SIMD intrinsics. The full-tile loop is software-pipelined: the
// weight, the input offset and the input vectors for a step are loaded during
// the previous step.
//
// Parameters, as used by the code below:
//   mc            - remaining extent of the vectorized dimension, in BYTES
//                   (asserted non-zero and a multiple of sizeof(float)).
//   nc            - number of outer iterations per tile; each one consumes one
//                   nidx_nnzmap entry and performs one strided output store.
//   input         - input pointer; it is moved by the byte offsets read from
//                   widx_dmap and by whole tiles between passes.
//   weights       - float stream; for every one of the nc iterations the first
//                   value seeds the accumulators (presumably a bias - confirm)
//                   and the following nnz values are the multipliers.
//   widx_dmap     - signed byte offsets added to input after each
//                   multiply-accumulate step.
//   nidx_nnzmap   - number of multiply-accumulate steps for each of the nc
//                   iterations.
//   output        - output pointer; advanced by output_stride BYTES after each
//                   of the nc stores.
//   output_stride - byte distance between consecutive stores of one tile.
//   params        - clamping bounds, read from params->wasmsimd.min / .max.
//
// NOTE(review): input is advanced by a whole tile after each pass without being
// rewound, so the offsets applied during one pass presumably sum to zero -
// confirm against the code that builds widx_dmap.
// NOTE(review): because of the look-ahead loads, the full-tile loop reads one
// weight and one widx_dmap entry more than it uses; the buffers must tolerate
// that - confirm with the caller.
19void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}_pipelined${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
20    size_t mc,
21    size_t nc,
22    const float*restrict input,
23    const float*restrict weights,
24    const int32_t*restrict widx_dmap,
25    const uint32_t*restrict nidx_nnzmap,
26    float*restrict output,
27    size_t output_stride,
28    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
29{
30  assert(mc != 0);
31  assert(mc % sizeof(float) == 0);
32  assert(nc != 0);
33
  // Clamping bounds. Each is a 64-bit load replicated into both vector halves.
  // NOTE(review): presumably min and max each hold two identical floats - confirm.
34  const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min);
35  const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max);
  // Bytes to step output back after nc strided stores so that it ends up one
  // tile width past the position where the pass started.
36  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
  // Main loop: one pass over all nc iterations per full tile of floats.
37  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    // Every pass restarts the weight, offset and count streams from the top.
38    const float*restrict w = weights;
39    const int32_t* dmap = widx_dmap;
40    const uint32_t* nnzmap = nidx_nnzmap;
    // Pipeline prologue: fetch the first weight, the first input offset and the
    // first input vectors before they are needed.
41    v128_t vw = wasm_v128_load32_splat(w); w += 1;
42    intptr_t diff = *dmap++;
43    $for M in range(0, MR, 4):
44      v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
45    size_t n = nc;
46    do {
47      uint32_t nnz = *nnzmap++;
      // The weight that is already loaded seeds every accumulator; the next
      // weight is then fetched for the first multiply.
48       $for M in range(0, MR, 4):
49        v128_t vacc${ABC[M:M+4]} = vw;
50      vw = wasm_v128_load32_splat(w); w += 1;
51
52      $if UNROLL > 1:
        // Unrolled steady state. Each step accumulates with the values already
        // in flight, then advances input by diff bytes and fetches the next
        // offset, weight and input vectors.
53        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
54          $for K in range(0, UNROLL):
55            $for M in range(0, MR, 4):
56              vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]},   vw));
57            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
58            diff = *dmap++;
59            vw = wasm_v128_load32_splat(w); w += 1;
60            $for M in range(0, MR, 4):
61              vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
62        }
63
      // Remaining steps, one at a time, in the same order of operations.
64      if XNN_LIKELY(nnz != 0) {
65        do {
66          $for M in range(0, MR, 4):
67            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
68          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
69
70          diff = *dmap++;
71          vw = wasm_v128_load32_splat(w); w += 1;
72          $for M in range(0, MR, 4):
73            vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
74        } while (--nnz != 0);
75      }
      // Clamp (upper bound first, then lower bound) and store the tile.
76      $for M in range(0, MR, 4):
77        v128_t vout${ABC[M:M+4]} = ${WASM_F32X4_MIN}(vmax, vacc${ABC[M:M+4]});
78      $for M in range(0, MR, 4):
79        vout${ABC[M:M+4]} = ${WASM_F32X4_MAX}(vmin, vout${ABC[M:M+4]});
80      wasm_v128_store(output, vout0123);
81      $for M in range(4, MR, 4):
82        wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
83      output = (float*restrict) ((uintptr_t) output + output_stride);
84    } while (--n != 0);
    // Undo the nc strided steps and move both pointers to the next tile.
85    output = (float*restrict) ((uintptr_t) output - output_decrement);
86    input += ${MR};
87    mc -= ${MR} * sizeof(float);
88  }
  // Tail: less than one full tile is left. It is processed in power-of-two
  // pieces selected by the bits of mc (mc itself is not modified here). This
  // path is not pipelined: every step loads its own offset, input and weight.
89  if XNN_UNLIKELY(mc != 0) {
90    $for LOG2M in reversed(range((MR - 1).bit_length())):
91      $SUBMR = 1 << LOG2M
      // Re-target output_decrement at the width of this piece. The adjustment
      // is emitted for every width, whether or not the piece is processed.
92      $if SUBMR * 2 >= MR:
93        output_decrement += ${MR - SUBMR} * sizeof(float);
94      $else:
95        output_decrement += ${SUBMR} * sizeof(float);
96      if (mc & (${SUBMR} * sizeof(float))) {
97        const float*restrict w = weights;
98        const int32_t* dmap = widx_dmap;
99        const uint32_t* nnzmap = nidx_nnzmap;
100        size_t n = nc;
101        do {
102          uint32_t nnz = *nnzmap++;
          // Seed the accumulator(s) with the next value of the weight stream.
103          $if SUBMR == 1:
104            v128_t vacc0 = wasm_v128_load32_splat(w); w += 1;
105          $elif SUBMR == 2:
106            v128_t vacc01 = wasm_v128_load32_splat(w); w += 1;
107          $else:
108            v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1;
109          $for M in range(4, SUBMR, 4):
110            v128_t vacc${ABC[M:M+4]} = vacc0123;
111          if XNN_LIKELY(nnz != 0) {
112            do {
              // One step: read the offset, load the input lanes (a full vector,
              // or a splat of two floats or of one float), advance input, then
              // load the weight and accumulate.
113              const intptr_t diff = *dmap++;
114              $if SUBMR >= 4:
115                const v128_t vi0123 = wasm_v128_load(input);
116              $elif SUBMR == 2:
117                const v128_t vi01 = wasm_v128_load64_splat(input);
118              $elif SUBMR == 1:
119                const v128_t vi0 = wasm_v128_load32_splat(input);
120              $for M in range(4, SUBMR, 4):
121                const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
122              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
123              const v128_t vw = wasm_v128_load32_splat(w); w += 1;
124              $if SUBMR == 1:
125                vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw));
126              $else:
127                $for M in range(0, SUBMR, 4):
128                  vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw));
129            } while (--nnz != 0);
130          }
          // Clamp (upper bound first, then lower bound).
131          $if SUBMR == 1:
132            v128_t vout${ABC[0]} = ${WASM_F32X4_MIN}(vmax, vacc${ABC[0]});
133            vout${ABC[0]} = ${WASM_F32X4_MAX}(vmin, vout${ABC[0]});
134          $else:
135            $for M in range(0, SUBMR, 4):
136              v128_t vout${ABC[M:min(M+4,SUBMR)]} = ${WASM_F32X4_MIN}(vmax, vacc${ABC[M:min(M+4,SUBMR)]});
137            $for M in range(0, SUBMR, 4):
138              vout${ABC[M:min(M+4,SUBMR)]} = ${WASM_F32X4_MAX}(vmin, vout${ABC[M:min(M+4,SUBMR)]});
          // Store only as many lanes as this piece is wide: whole vectors, the
          // low 64 bits written through a double pointer, or a single float.
139          $if SUBMR >= 4:
140            wasm_v128_store(output, vout0123);
141          $elif SUBMR == 2:
142            *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
143          $elif SUBMR == 1:
144            *output = wasm_f32x4_extract_lane(vout0, 0);
145
146          $for M in range(4, SUBMR, 4):
147            wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
148          output = (float*restrict) ((uintptr_t) output + output_stride);
149        } while (--n != 0);
        // Undo the nc strided steps and skip past the piece just written.
150        output = (float*restrict) ((uintptr_t) output - output_decrement);
151        input += ${SUBMR};
152      }
153  }
154}
155