xref: /aosp_15_r20/external/XNNPACK/src/f32-vmulcaddc/wasmsimd.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template preconditions (checked when the template is expanded):
//  - CHANNEL_TILE is a positive multiple of 4, i.e. a whole number of 4-float vectors;
//  - ROW_TILE is at least 1;
//  - ARCH is one of "ARM", "X86", "RELAXED";
//  - FMA may only be enabled together with ARCH == "RELAXED".
// ABC is the alphabet that slices such as ABC[0:4] turn into lane-range suffixes
// for the generated vector variable names.
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert ROW_TILE >= 1
9$assert ARCH in ["ARM", "X86", "RELAXED"]
10$assert not FMA or ARCH == "RELAXED"
11$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
12#include <assert.h>
13
14#include <wasm_simd128.h>
15
16#include <xnnpack/math.h>
17#include <xnnpack/vmulcaddc.h>
18
19
// Per-ARCH template variables:
//  - WASM_F32X4_MIN / WASM_F32X4_MAX: name of the 4-float min / max operation
//    emitted for the selected ARCH;
//  - ISA: "wasmrelaxedsimd" when ARCH is "RELAXED", otherwise "wasmsimd";
//  - ARCH_SUFFIX: empty for RELAXED without FMA, "_fma" when FMA is set,
//    otherwise "_" followed by the lower-cased ARCH name.
// ISA and ARCH_SUFFIX are spliced into the generated function name below.
20$WASM_F32X4_MIN={"ARM": "wasm_f32x4_min", "X86": "wasm_f32x4_pmin", "RELAXED": "__builtin_wasm_relaxed_min_f32x4"}[ARCH]
21$WASM_F32X4_MAX={"ARM": "wasm_f32x4_max", "X86": "wasm_f32x4_pmax", "RELAXED": "__builtin_wasm_relaxed_max_f32x4"}[ARCH]
22$ISA = "wasmsimd" if ARCH != "RELAXED" else "wasmrelaxedsimd"
23$ARCH_SUFFIX = "" if ARCH == "RELAXED" and not FMA else "_" + ("fma" if FMA else ARCH.lower())
// Multiply-by-constant, add-constant, clamp micro-kernel.
//
// For every row and channel this computes bias + scale * input (as one fused
// multiply-add when FMA is enabled), then applies the selected max operation
// with the `min` parameter followed by the selected min operation with the
// `max` parameter, and stores the result.
//
// Parameters, as used by the code below:
//  - rows:          number of rows to process; must be non-zero. Rows are
//                   consumed ROW_TILE at a time.
//  - channels:      size of one row in BYTES; must be a non-zero multiple of
//                   sizeof(float).
//  - input:         first input row.
//  - input_stride:  distance between consecutive input rows, in bytes.
//  - weights:       per-channel scale and bias. Each full channel tile holds
//                   CHANNEL_TILE scales followed by CHANNEL_TILE biases; the
//                   same weights are re-read from the start for every row pass.
//  - output:        first output row.
//  - output_stride: distance between consecutive output rows, in bytes.
//  - params:        supplies the clamping bounds via params->wasmsimd.min/max.
//
// NOTE(review): the function is tagged XNN_OOB_READS; the tail path below does
// load full 16-byte vectors from input and weights even when fewer than four
// floats remain. Presumably the tag documents exactly that - confirm against
// the macro definition.
24void xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__${ISA}${ARCH_SUFFIX}_${ROW_TILE}x(
25    size_t rows,
26    size_t channels,
27    const float*restrict input,
28    size_t input_stride,
29    const float*restrict weights,
30    float*restrict output,
31    size_t output_stride,
32    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
33{
34  assert(rows != 0);
35  assert(channels != 0);
36  assert(channels % sizeof(float) == 0);  // channels is a byte count covering whole floats
37
38  const float* i0 = input;
39  float* o0 = output;
  // Each further row pointer starts one (byte) stride after the previous row.
40  $for M in range(1, ROW_TILE):
41    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
42    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
43
  // A row pass advances every pointer by `channels` bytes; these byte increments
  // then move each pointer on to the row ROW_TILE rows further down.
44  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
45  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
46
  // Clamping bounds, each broadcast from a 64-bit load of the params field.
47  const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min);
48  const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max);
49  do {
    // When fewer rows remain than this pass handles, redirect the surplus row
    // pointers to the previous row so they recompute and rewrite the same data.
    // The two comparison forms (rows <= M and rows < M+1) are equivalent for
    // integers; the reason for alternating between them is not visible here.
50    $for M in range(1, ROW_TILE):
51      $if M % 2 == 0:
52        if XNN_UNPREDICTABLE(rows <= ${M}) {
53          i${M} = i${M-1};
54          o${M} = o${M-1};
55        }
56      $else:
57        if XNN_UNPREDICTABLE(rows < ${M+1}) {
58          i${M} = i${M-1};
59          o${M} = o${M-1};
60        }
61
62    const float* w = weights;  // restart at the first channel's weights for this row pass
63    size_t c = channels;       // remaining bytes in the current row(s)
    // Main loop: one full tile of CHANNEL_TILE channels per iteration.
64    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
65      const v128_t vscale${ABC[0:4]} = wasm_v128_load(w);
66      $for C in range(4, CHANNEL_TILE, 4):
67        const v128_t vscale${ABC[C:C+4]} = wasm_v128_load(w + ${C});
68
69      $for M in range(ROW_TILE):
70        v128_t vacc${M}x${ABC[0:4]} = wasm_v128_load(i${M});
71        $for C in range(4, CHANNEL_TILE, 4):
72          v128_t vacc${M}x${ABC[C:C+4]} = wasm_v128_load(i${M} + ${C});
73        i${M} += ${CHANNEL_TILE};
74
      // Biases of this tile sit CHANNEL_TILE floats after its scales.
75      $for C in range(0, CHANNEL_TILE, 4):
76        const v128_t vbias${ABC[C:C+4]} = wasm_v128_load(w + ${C + CHANNEL_TILE});
77
78      $for M in range(ROW_TILE):
79        $for C in range(0, CHANNEL_TILE, 4):
80          $if FMA:
81            vacc${M}x${ABC[C:C+4]} = __builtin_wasm_fma_f32x4(vbias${ABC[C:C+4]}, vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]});
82          $else:
83            vacc${M}x${ABC[C:C+4]} = wasm_f32x4_add(vbias${ABC[C:C+4]}, wasm_f32x4_mul(vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]}));
84
      // Clamp: first the max operation against vmin, then the min operation against vmax.
85      $for M in range(ROW_TILE):
86        $for C in range(0, CHANNEL_TILE, 4):
87          vacc${M}x${ABC[C:C+4]} = ${WASM_F32X4_MAX}(vmin, vacc${M}x${ABC[C:C+4]});
88
89      $for M in range(ROW_TILE):
90        $for C in range(0, CHANNEL_TILE, 4):
91          vacc${M}x${ABC[C:C+4]} = ${WASM_F32X4_MIN}(vmax, vacc${M}x${ABC[C:C+4]});
92
93      $for M in range(ROW_TILE):
94        wasm_v128_store(o${M}, vacc${M}x${ABC[0:4]});
95        $for C in range(4, CHANNEL_TILE, 4):
96          wasm_v128_store(o${M} + ${C}, vacc${M}x${ABC[C:C+4]});
97        o${M} += ${CHANNEL_TILE};
98
99      w += ${CHANNEL_TILE * 2};  // skip this tile's scales and biases
100    }
    // Only generated for tiles wider than one vector: handle the leftover
    // channels four at a time.
101    $if CHANNEL_TILE > 4:
102      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
103        const v128_t vscale = wasm_v128_load(w);
104
105        $for M in range(ROW_TILE):
106          v128_t vacc${M} = wasm_v128_load(i${M});
107          i${M} += 4;
108
109        const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE});  // bias stays CHANNEL_TILE floats after its scale
110
111        $for M in range(ROW_TILE):
112          $if FMA:
113            vacc${M} = __builtin_wasm_fma_f32x4(vbias, vscale, vacc${M});
114          $else:
115            vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M}));
116
117        $for M in range(ROW_TILE):
118          vacc${M} = ${WASM_F32X4_MAX}(vmin, vacc${M});
119
120        $for M in range(ROW_TILE):
121          vacc${M} = ${WASM_F32X4_MIN}(vmax, vacc${M});
122
123        $for M in range(ROW_TILE):
124          wasm_v128_store(o${M}, vacc${M});
125          o${M} += 4;
126
127        w += 4;  // advance by four scales only; the bias offset above is relative to w
128      }
    // Tail: 1 to 3 floats remain. Full vectors are loaded and computed, but
    // only the remaining c bytes are stored.
129    if XNN_UNLIKELY(c != 0) {
130      const v128_t vscale = wasm_v128_load(w);
131
132      $for M in range(ROW_TILE):
133        v128_t vacc${M} = wasm_v128_load(i${M});
134        i${M} = (const float*) ((uintptr_t) i${M} + c);
135
136      const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE});
137
138      $for M in range(ROW_TILE):
139        $if FMA:
140          vacc${M} = __builtin_wasm_fma_f32x4(vbias, vscale, vacc${M});
141        $else:
142          vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M}));
143
144      $for M in range(ROW_TILE):
145        vacc${M} = ${WASM_F32X4_MAX}(vmin, vacc${M});
146
147      $for M in range(ROW_TILE):
148        vacc${M} = ${WASM_F32X4_MIN}(vmax, vacc${M});
149
150      if (c & (2 * sizeof(float))) {
151        $for M in range(ROW_TILE):
152          *((double*) o${M}) = wasm_f64x2_extract_lane(vacc${M}, 0);  // write lanes 0-1 as one 64-bit store
153
154        $for M in range(ROW_TILE):
155          vacc${M} = wasm_v32x4_shuffle(vacc${M}, vacc${M}, 2, 3, 2, 3);  // bring lanes 2-3 down to lanes 0-1
156
157        $for M in range(ROW_TILE):
158          o${M} += 2;
159      }
160      if (c & (1 * sizeof(float))) {
161        $for M in range(ROW_TILE):
162          *o${M}++ = wasm_f32x4_extract_lane(vacc${M}, 0);  // write the one remaining float
163      }
164    }
    // Step every row pointer to the row it handles in the next pass.
165    $for M in range(ROW_TILE):
166      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
167      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
    // NOTE(review): doz is presumably "difference or zero" (saturating
    // subtraction) from xnnpack/math.h, so rows reaches 0 rather than
    // wrapping when fewer than ROW_TILE rows were left - confirm.
168    rows = doz(rows, ${ROW_TILE});
169  } while (rows != 0);
170}
171