// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ROW_TILE >= 1
$assert ACCUMULATORS >= 1
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


$ARCH_SUFFIX = "_x86" if X86 else "_arm"

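// Depthwise 3x3 convolution in CHW layout with 1-pixel ("same") padding.
// Each iteration of the generated kernel's inner loop computes ROW_TILE output
// rows and 4 output pixels per row, broadcasting ("splat"-ing) each filter tap
// across a SIMD register.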
void xnn_f32_dwconv2d_chw_ukernel_3x3p1__wasmsimd${ARCH_SUFFIX}_splat_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  const v128_t vmask = wasm_v128_load(params->scalar.mask);
  const v128_t vmax = wasm_v128_load32_splat(&params->scalar.max);
  const v128_t vmin = wasm_v128_load32_splat(&params->scalar.min);

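  // weights[0] is the bias; weights[1..9] are the 3x3 filter taps in row-major order.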
  const v128_t vw0123 = wasm_v128_load(weights);
  const v128_t vw4567 = wasm_v128_load(weights + 4);
  const v128_t vw89 = wasm_v128_load64_splat(weights + 8);

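  // Bytes by which each input row pointer advances over one full pass of a row
  // (input_width rounded up to a whole 4-float block); used to rewind the row
  // pointers for the next group of output rows.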
  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

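  // i0 points at the zero (top padding) row; i1 and the following pointers track
  // consecutive input rows.  Output row M reads input rows M-1, M and M+1.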
  const float* i0 = zero;
  const float* i1 = input;
  $for M in range(2, 2 + ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

  float* o0 = output;
  $for M in range(1, ROW_TILE):
    float* o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

  size_t output_height = input_height;
  do {
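    // If fewer output rows remain than the row tile covers, redirect the
    // out-of-bounds input rows to the zero (bottom padding) row and alias the extra
    // output pointers to the previous row; stores are issued from the highest row
    // to the lowest, so the valid row's data is written last.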
    $for M in range(2, 2 + ROW_TILE):
      if XNN_UNPREDICTABLE(output_height < ${M}) {
        i${M} = zero;
        $if M <= ROW_TILE:
          o${M-1} = o${M-2};
      }

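    // vi*x0123 holds the 4 pixels to the left of the current block and starts at
    // zero to provide the left padding column; vi*x4567 holds the current block.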
    $for M in range(2 + ROW_TILE):
      v128_t vi${M}x0123 = wasm_f32x4_const_splat(0.0f);

    $for M in range(2 + ROW_TILE):
      v128_t vi${M}x4567 = wasm_v128_load(i${M}); i${M} += 4;

    size_t w = input_width;
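    // Main loop: process a full block of 4 output pixels per row while more than
    // 4 input pixels remain; the last 1..4 pixels are handled after the loop.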
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
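      // Initialize each row's accumulator with the bias (weights[0]).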
      $for M in range(ROW_TILE):
        v128_t vo${M}p0 = wasm_v32x4_shuffle(vw0123, vw0123, 0, 0, 0, 0);

      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x89AB = wasm_v128_load(i${M}); i${M} += 4;

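      // Center-column taps: rows above, at and below the output row, multiplied by
      // weights 2, 5 and 8.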
      $for M in range(ROW_TILE):
        vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M}x4567, wasm_v32x4_shuffle(vw0123, vw0123, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 2:
          v128_t vo${M}p1 = wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 3:
          v128_t vo${M}p2 = wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0)));

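      // vi*x3456: the pixels one position to the left of the 4 output pixels (west
      // neighbours), used with the left-column taps (weights 1, 4 and 7).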
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x3456 = wasm_v32x4_shuffle(vi${M}x0123, vi${M}x4567, 3, 4, 5, 6);

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 4:
          v128_t vo${M}p3 = wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1));
        $else:
          vo${M}p${3 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${3 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1)));

      $for M in range(ROW_TILE):
        vo${M}p${4 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${4 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 0, 0, 0, 0)));

      $for M in range(ROW_TILE):
        vo${M}p${5 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${5 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 3, 3, 3, 3)));

      $for M in range(2 + ROW_TILE):
        vi${M}x0123 = vi${M}x4567;

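      // vi*x5678: the pixels one position to the right of the 4 output pixels (east
      // neighbours), pulling in the first pixel of the next block; used with the
      // right-column taps (weights 3, 6 and 9).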
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x5678 = wasm_v32x4_shuffle(vi${M}x4567, vi${M}x89AB, 1, 2, 3, 4);

      $for M in range(ROW_TILE):
        vo${M}p${6 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${6 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x5678, wasm_v32x4_shuffle(vw0123, vw0123, 3, 3, 3, 3)));

      $for M in range(ROW_TILE):
        vo${M}p${7 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${7 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x5678, wasm_v32x4_shuffle(vw4567, vw4567, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        vo${M}p${8 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${8 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x5678, wasm_v32x4_shuffle(vw89, vw89, 1, 1, 1, 1)));

      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = vi${M}x89AB;

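      // Tree-reduce the partial accumulators into vo*p0.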
      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

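      // Clamp to [min, max]: pmin/pmax lower to single MINPS/MAXPS instructions on
      // x86, while plain f32x4_min/max maps better to ARM.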
      $if X86:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_pmax(vmin, vo${M}p0);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_pmin(vmax, vo${M});
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

      $for M in reversed(range(ROW_TILE)):
        wasm_v128_store(o${M}, vo${M}); o${M} += 4;
    }
    // Always process the last block of 1..4 pixels.
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
      $for M in range(ROW_TILE):
        v128_t vo${M}p0 = wasm_v32x4_shuffle(vw0123, vw0123, 0, 0, 0, 0);

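      // Zero the input lanes past the end of the row so that over-read values do
      // not contribute to the output.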
      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = wasm_v128_and(vmask, vi${M}x4567);

      $for M in range(ROW_TILE):
        vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M}x4567, wasm_v32x4_shuffle(vw0123, vw0123, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 2:
          v128_t vo${M}p1 = wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 3:
          v128_t vo${M}p2 = wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0)));

      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x3456 = wasm_v32x4_shuffle(vi${M}x0123, vi${M}x4567, 3, 4, 5, 6);

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 4:
          v128_t vo${M}p3 = wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1));
        $else:
          vo${M}p${3 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${3 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1)));

      $for M in range(ROW_TILE):
        vo${M}p${4 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${4 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 0, 0, 0, 0)));

      $for M in range(ROW_TILE):
        vo${M}p${5 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${5 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 3, 3, 3, 3)));

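      // There is no next block to shift in: use zeros for the right padding column.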
      const v128_t vzero = wasm_f32x4_const_splat(0.0f);
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x5678 = wasm_v32x4_shuffle(vi${M}x4567, vzero, 1, 2, 3, 4);

      $for M in range(ROW_TILE):
        vo${M}p${6 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${6 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x5678, wasm_v32x4_shuffle(vw0123, vw0123, 3, 3, 3, 3)));

      $for M in range(ROW_TILE):
        vo${M}p${7 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${7 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x5678, wasm_v32x4_shuffle(vw4567, vw4567, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        vo${M}p${8 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${8 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x5678, wasm_v32x4_shuffle(vw89, vw89, 1, 1, 1, 1)));

      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $if X86:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_pmax(vmin, vo${M}p0);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_pmin(vmax, vo${M});
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

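      // Store a full block of 4 pixels or, at the end of the row, the remaining
      // 1..3 pixels in 2- and 1-pixel pieces.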
      if XNN_LIKELY(w == 4 * sizeof(float)) {
        $for M in reversed(range(ROW_TILE)):
          wasm_v128_store(o${M}, vo${M}); o${M} += 4;
      } else {
        if (w & (2 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *((double*) o${M}) = wasm_f64x2_extract_lane(vo${M}, 0); o${M} += 2;

          $for M in range(ROW_TILE):
            vo${M} = wasm_v32x4_shuffle(vo${M}, vo${M}, 2, 3, 0, 1);
        }
        if (w & (1 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *o${M} = wasm_f32x4_extract_lane(vo${M}, 0); o${M} += 1;
        }
      }
    }

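    // Rewind the two lowest input row pointers to the start of the rows needed by
    // the next group of output rows; the remaining pointers follow at input_width
    // intervals.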
    i0 = (const float*) ((uintptr_t) i${ROW_TILE} - input_decrement);
    i1 = (const float*) ((uintptr_t) i${ROW_TILE+1} - input_decrement);
    $for M in range(2, 2 + ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

    $if ROW_TILE > 1:
      o0 = o${ROW_TILE - 1};
      $for M in range(1, ROW_TILE):
        o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

    $if ROW_TILE > 1:
      output_height = doz(output_height, ${ROW_TILE});
  } while (${"--" if ROW_TILE == 1 else ""}output_height != 0);
}