xref: /aosp_15_r20/external/XNNPACK/src/f32-dwconv/up-wasmsimd.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2020 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker//
3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
$assert CHANNEL_TILE % 4 == 0  # channel loops step through v128 vectors of 4 f32 lanes
$assert KERNEL_TILE >= 2
$assert ACCUMULATORS >= 1
$assert ACTIVATION != "MINMAX" or ARCH in ["ARM", "X86", "RELAXED"]  # ARCH selects the f32x4 min/max intrinsics used for clamping
$assert not FMA or ARCH == "RELAXED"  # FMA variants use __builtin_wasm_fma_f32x4 and are named as relaxed-SIMD kernels
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # channel-index digits used to spell vector variable names, e.g. vacc0123, vacc4567
12*4bdc9457SAndroid Build Coastguard Worker#include <assert.h>
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Worker#include <wasm_simd128.h>
15*4bdc9457SAndroid Build Coastguard Worker
16*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/dwconv.h>
17*4bdc9457SAndroid Build Coastguard Worker
18*4bdc9457SAndroid Build Coastguard Worker
19*4bdc9457SAndroid Build Coastguard Worker$assert ACTIVATION in ["LINEAR", "RELU", "MINMAX"]
20*4bdc9457SAndroid Build Coastguard Worker$if ACTIVATION == "MINMAX":
21*4bdc9457SAndroid Build Coastguard Worker$  WASM_F32X4_MIN={"ARM": "wasm_f32x4_min", "X86": "wasm_f32x4_pmin", "RELAXED": "__builtin_wasm_relaxed_min_f32x4"}[ARCH]
22*4bdc9457SAndroid Build Coastguard Worker$  WASM_F32X4_MAX={"ARM": "wasm_f32x4_max", "X86": "wasm_f32x4_pmax", "RELAXED": "__builtin_wasm_relaxed_max_f32x4"}[ARCH]
23*4bdc9457SAndroid Build Coastguard Worker$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower())
24*4bdc9457SAndroid Build Coastguard Worker$ISA = "wasmsimd" if not FMA and (ACTIVATION in ["LINEAR", "RELU"] or ARCH != "RELAXED") else "wasmrelaxedsimd"
25*4bdc9457SAndroid Build Coastguard Worker$ARCH_SUFFIX = "" if not FMA and (ACTIVATION in ["LINEAR", "RELU"] or ARCH == "RELAXED") else "_" + ("fma" if FMA else ARCH.lower())
26*4bdc9457SAndroid Build Coastguard Worker$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
27*4bdc9457SAndroid Build Coastguard Workervoid xnn_f32_dwconv${ACTIVATION_SUFFIX}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${ISA}${ARCH_SUFFIX}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
28*4bdc9457SAndroid Build Coastguard Worker    size_t channels,
29*4bdc9457SAndroid Build Coastguard Worker    size_t output_width,
30*4bdc9457SAndroid Build Coastguard Worker    const float** input,
31*4bdc9457SAndroid Build Coastguard Worker    const float* weights,
32*4bdc9457SAndroid Build Coastguard Worker    float* output,
33*4bdc9457SAndroid Build Coastguard Worker    size_t input_stride,
34*4bdc9457SAndroid Build Coastguard Worker    size_t output_increment,
35*4bdc9457SAndroid Build Coastguard Worker    size_t input_offset,
36*4bdc9457SAndroid Build Coastguard Worker    const float* zero,
37*4bdc9457SAndroid Build Coastguard Worker    const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
38*4bdc9457SAndroid Build Coastguard Worker{
39*4bdc9457SAndroid Build Coastguard Worker  assert(channels != 0);
40*4bdc9457SAndroid Build Coastguard Worker  assert(output_width != 0);
41*4bdc9457SAndroid Build Coastguard Worker
42*4bdc9457SAndroid Build Coastguard Worker  $if ACTIVATION == "MINMAX":
43*4bdc9457SAndroid Build Coastguard Worker    const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min);
44*4bdc9457SAndroid Build Coastguard Worker    const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max);
45*4bdc9457SAndroid Build Coastguard Worker  $elif ACTIVATION == "RELU":
46*4bdc9457SAndroid Build Coastguard Worker    const v128_t vzero = wasm_i32x4_const_splat(0);
47*4bdc9457SAndroid Build Coastguard Worker  do {
48*4bdc9457SAndroid Build Coastguard Worker    $for K in range(KERNEL_TILE):
49*4bdc9457SAndroid Build Coastguard Worker      const float* i${K} = input[${K}];
50*4bdc9457SAndroid Build Coastguard Worker      assert(i${K} != NULL);
51*4bdc9457SAndroid Build Coastguard Worker      if XNN_UNPREDICTABLE(i${K} != zero) {
52*4bdc9457SAndroid Build Coastguard Worker        i${K} = (const float*) ((uintptr_t) i${K} + input_offset);
53*4bdc9457SAndroid Build Coastguard Worker      }
54*4bdc9457SAndroid Build Coastguard Worker    input = (const float**) ((uintptr_t) input + input_stride);
55*4bdc9457SAndroid Build Coastguard Worker
56*4bdc9457SAndroid Build Coastguard Worker    size_t c = channels;
57*4bdc9457SAndroid Build Coastguard Worker    const float* w = weights;
58*4bdc9457SAndroid Build Coastguard Worker    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
59*4bdc9457SAndroid Build Coastguard Worker      v128_t vacc${ABC[0:4]}p0 = wasm_v128_load(w);
60*4bdc9457SAndroid Build Coastguard Worker      $for C in range(4, CHANNEL_TILE, 4):
61*4bdc9457SAndroid Build Coastguard Worker        v128_t vacc${ABC[C:C+4]}p0 = wasm_v128_load(w + ${C});
62*4bdc9457SAndroid Build Coastguard Worker
63*4bdc9457SAndroid Build Coastguard Worker      $for K in range(KERNEL_TILE):
64*4bdc9457SAndroid Build Coastguard Worker
65*4bdc9457SAndroid Build Coastguard Worker        const v128_t vi${K}x${ABC[0:4]} = wasm_v128_load(i${K});
66*4bdc9457SAndroid Build Coastguard Worker        $for C in range(4, CHANNEL_TILE, 4):
67*4bdc9457SAndroid Build Coastguard Worker          const v128_t vi${K}x${ABC[C:C+4]} = wasm_v128_load(i${K} + ${C});
68*4bdc9457SAndroid Build Coastguard Worker        i${K} += ${CHANNEL_TILE};
69*4bdc9457SAndroid Build Coastguard Worker
70*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 4):
71*4bdc9457SAndroid Build Coastguard Worker          const v128_t vk${K}x${ABC[C:C+4]} = wasm_v128_load(w + ${(K + 1) * CHANNEL_TILE + C});
72*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 4):
73*4bdc9457SAndroid Build Coastguard Worker          $if 1 <= K < ACCUMULATORS:
74*4bdc9457SAndroid Build Coastguard Worker            v128_t vacc${ABC[C:C+4]}p${K} = wasm_f32x4_mul(vi${K}x${ABC[C:C+4]}, vk${K}x${ABC[C:C+4]});
75*4bdc9457SAndroid Build Coastguard Worker          $else:
76*4bdc9457SAndroid Build Coastguard Worker            $if FMA:
77*4bdc9457SAndroid Build Coastguard Worker              vacc${ABC[C:C+4]}p${K % ACCUMULATORS} = __builtin_wasm_fma_f32x4(vacc${ABC[C:C+4]}p${K % ACCUMULATORS}, vi${K}x${ABC[C:C+4]}, vk${K}x${ABC[C:C+4]});
78*4bdc9457SAndroid Build Coastguard Worker            $else:
79*4bdc9457SAndroid Build Coastguard Worker              vacc${ABC[C:C+4]}p${K % ACCUMULATORS} = wasm_f32x4_add(vacc${ABC[C:C+4]}p${K % ACCUMULATORS}, wasm_f32x4_mul(vi${K}x${ABC[C:C+4]}, vk${K}x${ABC[C:C+4]}));
80*4bdc9457SAndroid Build Coastguard Worker
81*4bdc9457SAndroid Build Coastguard Worker      w += ${(KERNEL_TILE + 1) * CHANNEL_TILE};
82*4bdc9457SAndroid Build Coastguard Worker
83*4bdc9457SAndroid Build Coastguard Worker      $if ACCUMULATORS > 1:
84*4bdc9457SAndroid Build Coastguard Worker        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
85*4bdc9457SAndroid Build Coastguard Worker        $ACC_SLICE = 1
86*4bdc9457SAndroid Build Coastguard Worker        $while ACC_SLICE < ACCUMULATORS:
87*4bdc9457SAndroid Build Coastguard Worker          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
88*4bdc9457SAndroid Build Coastguard Worker            $if A + ACC_SLICE < ACCUMULATORS:
89*4bdc9457SAndroid Build Coastguard Worker              $for C in range(0, CHANNEL_TILE, 4):
90*4bdc9457SAndroid Build Coastguard Worker                vacc${ABC[C:C+4]}p${A} = wasm_f32x4_add(vacc${ABC[C:C+4]}p${A}, vacc${ABC[C:C+4]}p${A + ACC_SLICE});
91*4bdc9457SAndroid Build Coastguard Worker          $ACC_SLICE *= 2
92*4bdc9457SAndroid Build Coastguard Worker
93*4bdc9457SAndroid Build Coastguard Worker      $if ACTIVATION == "MINMAX":
94*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 4):
95*4bdc9457SAndroid Build Coastguard Worker          v128_t vacc${ABC[C:C+4]} = ${WASM_F32X4_MAX}(vmin, vacc${ABC[C:C+4]}p0);
96*4bdc9457SAndroid Build Coastguard Worker
97*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 4):
98*4bdc9457SAndroid Build Coastguard Worker          vacc${ABC[C:C+4]} = ${WASM_F32X4_MIN}(vmax, vacc${ABC[C:C+4]});
99*4bdc9457SAndroid Build Coastguard Worker      $elif ACTIVATION == "RELU":
100*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 4):
101*4bdc9457SAndroid Build Coastguard Worker          const v128_t vacc${ABC[C:C+4]} = wasm_i32x4_max(vacc${ABC[C:C+4]}p0, vzero);
102*4bdc9457SAndroid Build Coastguard Worker      $elif ACTIVATION == "LINEAR":
103*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 4):
104*4bdc9457SAndroid Build Coastguard Worker          const v128_t vacc${ABC[C:C+4]} = vacc${ABC[C:C+4]}p0;
105*4bdc9457SAndroid Build Coastguard Worker
106*4bdc9457SAndroid Build Coastguard Worker      wasm_v128_store(output, vacc${ABC[0:4]});
107*4bdc9457SAndroid Build Coastguard Worker      $for C in range(4, CHANNEL_TILE, 4):
108*4bdc9457SAndroid Build Coastguard Worker        wasm_v128_store(output + ${C}, vacc${ABC[C:C+4]});
109*4bdc9457SAndroid Build Coastguard Worker      output += ${CHANNEL_TILE};
110*4bdc9457SAndroid Build Coastguard Worker    }
111*4bdc9457SAndroid Build Coastguard Worker    $if CHANNEL_TILE > 4:
112*4bdc9457SAndroid Build Coastguard Worker      for (; c >= 4; c -= 4) {
113*4bdc9457SAndroid Build Coastguard Worker        v128_t vacc0123p0 = wasm_v128_load(w);
114*4bdc9457SAndroid Build Coastguard Worker        $for K in range(KERNEL_TILE):
115*4bdc9457SAndroid Build Coastguard Worker
116*4bdc9457SAndroid Build Coastguard Worker          const v128_t vi${K}x0123 = wasm_v128_load(i${K});
117*4bdc9457SAndroid Build Coastguard Worker          i${K} += 4;
118*4bdc9457SAndroid Build Coastguard Worker
119*4bdc9457SAndroid Build Coastguard Worker          const v128_t vk${K}x0123 = wasm_v128_load(w + ${(K + 1) * CHANNEL_TILE});
120*4bdc9457SAndroid Build Coastguard Worker          $if 1 <= K < ACCUMULATORS:
121*4bdc9457SAndroid Build Coastguard Worker            v128_t vacc0123p${K} = wasm_f32x4_mul(vi${K}x0123, vk${K}x0123);
122*4bdc9457SAndroid Build Coastguard Worker          $else:
123*4bdc9457SAndroid Build Coastguard Worker            $if FMA:
124*4bdc9457SAndroid Build Coastguard Worker              vacc0123p${K % ACCUMULATORS} = __builtin_wasm_fma_f32x4(vacc0123p${K % ACCUMULATORS}, vi${K}x0123, vk${K}x0123);
125*4bdc9457SAndroid Build Coastguard Worker            $else:
126*4bdc9457SAndroid Build Coastguard Worker              vacc0123p${K % ACCUMULATORS} = wasm_f32x4_add(vacc0123p${K % ACCUMULATORS}, wasm_f32x4_mul(vi${K}x0123, vk${K}x0123));
127*4bdc9457SAndroid Build Coastguard Worker
128*4bdc9457SAndroid Build Coastguard Worker        w += 4;
129*4bdc9457SAndroid Build Coastguard Worker
130*4bdc9457SAndroid Build Coastguard Worker        $if ACCUMULATORS > 1:
131*4bdc9457SAndroid Build Coastguard Worker          // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
132*4bdc9457SAndroid Build Coastguard Worker          $ACC_SLICE = 1
133*4bdc9457SAndroid Build Coastguard Worker          $while ACC_SLICE < ACCUMULATORS:
134*4bdc9457SAndroid Build Coastguard Worker            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
135*4bdc9457SAndroid Build Coastguard Worker              $if A + ACC_SLICE < ACCUMULATORS:
136*4bdc9457SAndroid Build Coastguard Worker                vacc0123p${A} = wasm_f32x4_add(vacc0123p${A}, vacc0123p${A + ACC_SLICE});
137*4bdc9457SAndroid Build Coastguard Worker            $ACC_SLICE *= 2
138*4bdc9457SAndroid Build Coastguard Worker
139*4bdc9457SAndroid Build Coastguard Worker        $if ACTIVATION == "MINMAX":
140*4bdc9457SAndroid Build Coastguard Worker          v128_t vacc0123 = ${WASM_F32X4_MAX}(vmin, vacc0123p0);
141*4bdc9457SAndroid Build Coastguard Worker          vacc0123 = ${WASM_F32X4_MIN}(vmax, vacc0123);
142*4bdc9457SAndroid Build Coastguard Worker        $elif ACTIVATION == "RELU":
143*4bdc9457SAndroid Build Coastguard Worker          const v128_t vacc0123 = wasm_i32x4_max(vacc0123p0, vzero);
144*4bdc9457SAndroid Build Coastguard Worker        $elif ACTIVATION == "LINEAR":
145*4bdc9457SAndroid Build Coastguard Worker          const v128_t vacc0123 = vacc0123p0;
146*4bdc9457SAndroid Build Coastguard Worker
147*4bdc9457SAndroid Build Coastguard Worker        wasm_v128_store(output, vacc0123);
148*4bdc9457SAndroid Build Coastguard Worker        output += 4;
149*4bdc9457SAndroid Build Coastguard Worker      }
150*4bdc9457SAndroid Build Coastguard Worker    if XNN_UNLIKELY(c != 0) {
151*4bdc9457SAndroid Build Coastguard Worker      v128_t vacc0123p0 = wasm_v128_load(w);
152*4bdc9457SAndroid Build Coastguard Worker      $for K in range(KERNEL_TILE):
153*4bdc9457SAndroid Build Coastguard Worker
154*4bdc9457SAndroid Build Coastguard Worker        const v128_t vi${K}x0123 = wasm_v128_load(i${K});
155*4bdc9457SAndroid Build Coastguard Worker        const v128_t vk${K}x0123 = wasm_v128_load(w + ${(K+1) * CHANNEL_TILE});
156*4bdc9457SAndroid Build Coastguard Worker        $if 1 <= K < ACCUMULATORS:
157*4bdc9457SAndroid Build Coastguard Worker          v128_t vacc0123p${K} = wasm_f32x4_mul(vi${K}x0123, vk${K}x0123);
158*4bdc9457SAndroid Build Coastguard Worker        $else:
159*4bdc9457SAndroid Build Coastguard Worker          $if FMA:
160*4bdc9457SAndroid Build Coastguard Worker            vacc0123p${K % ACCUMULATORS} = __builtin_wasm_fma_f32x4(vacc0123p${K % ACCUMULATORS}, vi${K}x0123, vk${K}x0123);
161*4bdc9457SAndroid Build Coastguard Worker          $else:
162*4bdc9457SAndroid Build Coastguard Worker            vacc0123p${K % ACCUMULATORS} = wasm_f32x4_add(vacc0123p${K % ACCUMULATORS}, wasm_f32x4_mul(vi${K}x0123, vk${K}x0123));
163*4bdc9457SAndroid Build Coastguard Worker
164*4bdc9457SAndroid Build Coastguard Worker      $if ACCUMULATORS > 1:
165*4bdc9457SAndroid Build Coastguard Worker        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
166*4bdc9457SAndroid Build Coastguard Worker        $ACC_SLICE = 1
167*4bdc9457SAndroid Build Coastguard Worker        $while ACC_SLICE < ACCUMULATORS:
168*4bdc9457SAndroid Build Coastguard Worker          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
169*4bdc9457SAndroid Build Coastguard Worker            $if A + ACC_SLICE < ACCUMULATORS:
170*4bdc9457SAndroid Build Coastguard Worker              vacc0123p${A} = wasm_f32x4_add(vacc0123p${A}, vacc0123p${A + ACC_SLICE});
171*4bdc9457SAndroid Build Coastguard Worker          $ACC_SLICE *= 2
172*4bdc9457SAndroid Build Coastguard Worker
173*4bdc9457SAndroid Build Coastguard Worker      $if ACTIVATION == "MINMAX":
174*4bdc9457SAndroid Build Coastguard Worker        v128_t vacc0123 = ${WASM_F32X4_MAX}(vmin, vacc0123p0);
175*4bdc9457SAndroid Build Coastguard Worker        vacc0123 = ${WASM_F32X4_MIN}(vmax, vacc0123);
176*4bdc9457SAndroid Build Coastguard Worker      $elif ACTIVATION == "RELU":
177*4bdc9457SAndroid Build Coastguard Worker        v128_t vacc0123 = wasm_i32x4_max(vacc0123p0, vzero);
178*4bdc9457SAndroid Build Coastguard Worker      $elif ACTIVATION == "LINEAR":
179*4bdc9457SAndroid Build Coastguard Worker        v128_t vacc0123 = vacc0123p0;
180*4bdc9457SAndroid Build Coastguard Worker
181*4bdc9457SAndroid Build Coastguard Worker      if (c & 2) {
182*4bdc9457SAndroid Build Coastguard Worker        *((double*) output) = wasm_f64x2_extract_lane(vacc0123, 0);
183*4bdc9457SAndroid Build Coastguard Worker        vacc0123 = wasm_v32x4_shuffle(vacc0123, vacc0123, 2, 3, 2, 3);
184*4bdc9457SAndroid Build Coastguard Worker        output += 2;
185*4bdc9457SAndroid Build Coastguard Worker      }
186*4bdc9457SAndroid Build Coastguard Worker      if (c & 1) {
187*4bdc9457SAndroid Build Coastguard Worker        *output = wasm_f32x4_extract_lane(vacc0123, 0);
188*4bdc9457SAndroid Build Coastguard Worker        output += 1;
189*4bdc9457SAndroid Build Coastguard Worker      }
190*4bdc9457SAndroid Build Coastguard Worker    }
191*4bdc9457SAndroid Build Coastguard Worker
192*4bdc9457SAndroid Build Coastguard Worker    output = (float*) ((uintptr_t) output + output_increment);
193*4bdc9457SAndroid Build Coastguard Worker  } while (--output_width != 0);
194*4bdc9457SAndroid Build Coastguard Worker}
195