// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ROW_TILE >= 1
$assert ACCUMULATORS >= 1
#include <assert.h>

#include <wasm_simd128.h>


#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


$ARCH_SUFFIX = "_x86" if X86 else "_arm"

void xnn_f32_dwconv2d_chw_ukernel_3x3p1__wasmsimd${ARCH_SUFFIX}_loadsplat_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  const v128_t vmask = wasm_v128_load(params->scalar.mask);
  const v128_t vmax = wasm_v128_load32_splat(&params->scalar.max);
  const v128_t vmin = wasm_v128_load32_splat(&params->scalar.min);

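  // weights[0..9] = bias followed by the 3x3 kernel in row-major order; every value is
  // broadcast into its own register once, outside the pixel loop ("loadsplat").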
  const v128_t vw0123 = wasm_v128_load(weights);
  const v128_t vw4567 = wasm_v128_load(weights + 4);
  const v128_t vw89 = wasm_v128_load64_splat(weights + 8);
  const v128_t vbias = wasm_v32x4_shuffle(vw0123, vw0123, 0, 0, 0, 0);
  const v128_t vk00 = wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1);
  const v128_t vk01 = wasm_v32x4_shuffle(vw0123, vw0123, 2, 2, 2, 2);
  const v128_t vk02 = wasm_v32x4_shuffle(vw0123, vw0123, 3, 3, 3, 3);
  const v128_t vk10 = wasm_v32x4_shuffle(vw4567, vw4567, 0, 0, 0, 0);
  const v128_t vk11 = wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1);
  const v128_t vk12 = wasm_v32x4_shuffle(vw4567, vw4567, 2, 2, 2, 2);
  const v128_t vk20 = wasm_v32x4_shuffle(vw4567, vw4567, 3, 3, 3, 3);
  const v128_t vk21 = wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0);
  const v128_t vk22 = wasm_v32x4_shuffle(vw89, vw89, 1, 1, 1, 1);

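  // Each row pointer advances by input_width rounded up to a whole block of 4 floats while
  // its row is processed, so subtracting input_decrement rewinds it back to the row start.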
  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

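  // padding_top == 1, so the row above the image is read from the zero buffer.
  // Output row m is computed from input rows m, m+1, and m+2 of this pointer set.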
  const float* i0 = zero;
  const float* i1 = input;
  $for M in range(2, 2 + ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

  float* o0 = output;
  $for M in range(1, ROW_TILE):
    float* o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

  size_t output_height = input_height;
  do {
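    // Input rows below the bottom of the image read from the zero buffer, and the output
    // pointers of rows that will not be produced are aliased to the previous row so that
    // their stores stay inside the output buffer.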
    $for M in range(2, 2 + ROW_TILE):
      if XNN_UNPREDICTABLE(output_height < ${M}) {
        i${M} = zero;
        $if M <= ROW_TILE:
          o${M-1} = o${M-2};
      }

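    // The x0123 vectors hold the 4 pixels to the left of the current block; they start at
    // zero to provide the implicit left padding.  The first 4 pixels of each row are
    // preloaded into the x4567 vectors.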
    $for M in range(2 + ROW_TILE):
      v128_t vi${M}x0123 = wasm_f32x4_const_splat(0.0f);

    $for M in range(2 + ROW_TILE):
      v128_t vi${M}x4567 = wasm_v128_load(i${M});
      i${M} += 4;

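    // Main loop: produce 4 output pixels per row per iteration while more than 4 pixels
    // remain; the final 1..4 pixels are handled by the masked block below.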
    size_t w = input_width;
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x89AB = wasm_v128_load(i${M});
        i${M} += 4;

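      // Center column of the stencil: output row m accumulates input rows m, m+1, and m+2
      // multiplied by the center taps k01, k11, k21, with the bias folded into the first term.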
      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K == 0:
            v128_t vo${M}p0 = wasm_f32x4_add(vbias, wasm_f32x4_mul(vi${M+K}x4567, vk${K}1));
          $elif K < ACCUMULATORS:
            v128_t vo${M}p${K} = wasm_f32x4_mul(vi${M+K}x4567, vk${K}1);
          $else:
            vo${M}p${K % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${K % ACCUMULATORS}, wasm_f32x4_mul(vi${M+K}x4567, vk${K}1));

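      // x3456 = the window shifted one pixel left (lane 3 of the previous block followed by
      // lanes 0-2 of the current one), multiplied by the left-column taps k00, k10, k20.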
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x3456 = wasm_v32x4_shuffle(vi${M}x0123, vi${M}x4567, 3, 4, 5, 6);

      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K+3 < ACCUMULATORS:
            v128_t vo${M}p${K+3} = wasm_f32x4_mul(vi${M+K}x3456, vk${K}0);
          $else:
            vo${M}p${(K+3) % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${(K+3) % ACCUMULATORS}, wasm_f32x4_mul(vi${M+K}x3456, vk${K}0));

      $for M in range(2 + ROW_TILE):
        vi${M}x0123 = vi${M}x4567;

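      // x5678 = the window shifted one pixel right (lanes 1-3 of the current block followed
      // by lane 0 of the next one), multiplied by the right-column taps k02, k12, k22.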
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x5678 = wasm_v32x4_shuffle(vi${M}x4567, vi${M}x89AB, 1, 2, 3, 4);

      $for K in range(3):
        $for M in range(ROW_TILE):
          vo${M}p${(K+6) % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${(K+6) % ACCUMULATORS}, wasm_f32x4_mul(vi${M+K}x5678, vk${K}2));

      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = vi${M}x89AB;

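      // If more than one accumulator was used, sum the partial results pairwise into vo*p0.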
      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

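      // Clamp to [min, max]: the x86 variant uses pmin/pmax (which map well to MINPS/MAXPS),
      // while plain f32x4_min/max lower better on ARM targets.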
      $if X86:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_pmax(vmin, vo${M}p0);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_pmin(vmax, vo${M});
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

      $for M in reversed(range(ROW_TILE)):
        wasm_v128_store(o${M}, vo${M});
        o${M} += 4;
    }
    // Always process the last block of 1..4 pixels.
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
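      // The kernel may read past the end of the row (XNN_OOB_READS); vmask zeroes the lanes
      // beyond the 1..4 valid pixels of this final block.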
      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = wasm_v128_and(vmask, vi${M}x4567);

      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K == 0:
            v128_t vo${M}p0 = wasm_f32x4_add(vbias, wasm_f32x4_mul(vi${M+K}x4567, vk${K}1));
          $elif K < ACCUMULATORS:
            v128_t vo${M}p${K} = wasm_f32x4_mul(vi${M+K}x4567, vk${K}1);
          $else:
            vo${M}p${K % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${K % ACCUMULATORS}, wasm_f32x4_mul(vi${M+K}x4567, vk${K}1));

      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x3456 = wasm_v32x4_shuffle(vi${M}x0123, vi${M}x4567, 3, 4, 5, 6);

      $for K in range(3):
        $for M in range(ROW_TILE):
          $if K+3 < ACCUMULATORS:
            v128_t vo${M}p${K+3} = wasm_f32x4_mul(vi${M+K}x3456, vk${K}0);
          $else:
            vo${M}p${(K+3) % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${(K+3) % ACCUMULATORS}, wasm_f32x4_mul(vi${M+K}x3456, vk${K}0));

      const v128_t vzero = wasm_f32x4_const_splat(0.0f);
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x5678 = wasm_v32x4_shuffle(vi${M}x4567, vzero, 1, 2, 3, 4);

      $for K in range(3):
        $for M in range(ROW_TILE):
          vo${M}p${(K+6) % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${(K+6) % ACCUMULATORS}, wasm_f32x4_mul(vi${M+K}x5678, vk${K}2));

      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $if X86:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_pmax(vmin, vo${M}p0);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_pmin(vmax, vo${M});
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

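      // Store the last 1..4 pixels of each row: a full 16-byte store when exactly 4 remain,
      // otherwise 2 pixels from the low half (then shift the high half down) and/or 1 pixel.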
      if XNN_LIKELY(w == 4 * sizeof(float)) {
        $for M in reversed(range(ROW_TILE)):
          wasm_v128_store(o${M}, vo${M});
          o${M} += 4;
      } else {
        if (w & (2 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *((double*) o${M}) = wasm_f64x2_extract_lane(vo${M}, 0);
            o${M} += 2;

          $for M in range(ROW_TILE):
            vo${M} = wasm_v32x4_shuffle(vo${M}, vo${M}, 2, 3, 0, 1);
        }
        if (w & (1 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *o${M} = wasm_f32x4_extract_lane(vo${M}, 0);
            o${M} += 1;
        }
      }
    }

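    // Advance to the next block of ROW_TILE output rows: rewind i0 and i1 to the start of
    // the rows the next block reads first; the remaining pointers follow at input_width stride.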
    i0 = (const float*) ((uintptr_t) i${ROW_TILE} - input_decrement);
    i1 = (const float*) ((uintptr_t) i${ROW_TILE+1} - input_decrement);
    $for M in range(2, 2 + ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

    $if ROW_TILE > 1:
      o0 = o${ROW_TILE - 1};
      $for M in range(1, ROW_TILE):
        o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

    $if ROW_TILE > 1:
      output_height = doz(output_height, ${ROW_TILE});
  } while (${"--" if ROW_TILE == 1 else ""}output_height != 0);
}