// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/vunary.h>
#include <xnnpack/common.h>


// ELU (Exponential Linear Unit) micro-kernel, templated over BATCH_TILE and the
// X86/ARM flavor of WAsm SIMD:
//
//   y := beta * x                          for x >= 0
//   y := alpha * (exp(prescale * x) - 1)   for x < 0
//
// exp() is evaluated with a range reduction that splits ln(2) into high and low
// parts ("rr2") and a degree-6 ("p6") polynomial approximation on the reduced
// argument.
void xnn_f32_velu_ukernel__wasmsimd_${"x86" if X86 else "arm"}_rr2_p6_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const v128_t vprescale = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.prescale);
  const v128_t valpha = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.alpha);
  const v128_t vbeta = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.beta);
  const v128_t vsat_cutoff = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.sat_cutoff);
  const v128_t vmagic_bias = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.magic_bias);
  const v128_t vlog2e = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.log2e);
  const v128_t vminus_ln2_hi = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.minus_ln2_hi);
  const v128_t vminus_ln2_lo = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.minus_ln2_lo);
  const v128_t vc6 = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.c6);
  const v128_t vc5 = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.c5);
  const v128_t vc4 = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.c4);
  const v128_t vc3 = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.c3);
  const v128_t vc2 = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.c2);
  const v128_t vone = wasm_v128_load64_splat(params->wasmsimd_rr2_p6.one);

  $if BATCH_TILE > 4:
    // Main loop: a full batch tile per iteration.
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      v128_t vx${ABC[0:4]} = wasm_v128_load(x);
      $for N in range(4, BATCH_TILE, 4):
        v128_t vx${ABC[N:N+4]} = wasm_v128_load(x + ${N});
      x += ${BATCH_TILE};

      // z := prescale * x. The ARM variant clamps z at sat_cutoff here; the X86
      // variant instead records saturated lanes in a mask (vsatm below) and
      // later zeroes s and t for those lanes.
      $for N in range(0, BATCH_TILE, 4):
        $if X86:
          const v128_t vz${ABC[N:N+4]} = wasm_f32x4_mul(vx${ABC[N:N+4]}, vprescale);
        $else:
          const v128_t vz${ABC[N:N+4]} = wasm_f32x4_max(wasm_f32x4_mul(vx${ABC[N:N+4]}, vprescale), vsat_cutoff);

      // n := round(z / ln(2)), using the "magic bias" rounding trick.
      $for N in range(0, BATCH_TILE, 4):
        v128_t vn${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vz${ABC[N:N+4]}, vlog2e), vmagic_bias);

      // s := 2**n, formed by shifting the integer bits of n into the float exponent.
      $for N in range(0, BATCH_TILE, 4):
        v128_t vs${ABC[N:N+4]} = wasm_i32x4_shl(vn${ABC[N:N+4]}, 23);

      $for N in range(0, BATCH_TILE, 4):
        vn${ABC[N:N+4]} = wasm_f32x4_sub(vn${ABC[N:N+4]}, vmagic_bias);

      // t := z - n * ln(2), with ln(2) split into high and low parts for accuracy.
      $for N in range(0, BATCH_TILE, 4):
        v128_t vt${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vn${ABC[N:N+4]}, vminus_ln2_hi), vz${ABC[N:N+4]});
        $if X86:
          const v128_t vsatm${ABC[N:N+4]} = wasm_f32x4_le(vz${ABC[N:N+4]}, vsat_cutoff);

      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vn${ABC[N:N+4]}, vminus_ln2_lo), vt${ABC[N:N+4]});
        $if X86:
          vs${ABC[N:N+4]} = wasm_v128_andnot(vs${ABC[N:N+4]}, vsatm${ABC[N:N+4]});

      // p := t * (c2 + t * (c3 + t * (c4 + t * (c5 + t * c6)))),
      // so that exp(t) ~ 1 + t + t * p on the reduced range.
      $for N in range(0, BATCH_TILE, 4):
        $if X86:
          vt${ABC[N:N+4]} = wasm_v128_andnot(vt${ABC[N:N+4]}, vsatm${ABC[N:N+4]});
        v128_t vp${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vc6, vt${ABC[N:N+4]}), vc5);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc4);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc3);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc2);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

      // e := alpha * (exp(z) - 1) ~ alpha * (s * t * (1 + p) + (s - 1)).
      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = wasm_f32x4_mul(vt${ABC[N:N+4]}, vs${ABC[N:N+4]});
        vs${ABC[N:N+4]} = wasm_f32x4_sub(vs${ABC[N:N+4]}, vone);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = wasm_f32x4_add(wasm_f32x4_mul(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        const v128_t ve${ABC[N:N+4]} = wasm_f32x4_mul(wasm_f32x4_add(vp${ABC[N:N+4]}, vs${ABC[N:N+4]}), valpha);

      // Select the exponential branch for negative x (sign mask is all ones)
      // and beta * x otherwise.
      $for N in range(0, BATCH_TILE, 4):
        const v128_t vsignm${ABC[N:N+4]} = wasm_i32x4_shr(vx${ABC[N:N+4]}, 31);
        vx${ABC[N:N+4]} = wasm_f32x4_mul(vx${ABC[N:N+4]}, vbeta);

      $for N in range(0, BATCH_TILE, 4):
        const v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(ve${ABC[N:N+4]}, vx${ABC[N:N+4]}, vsignm${ABC[N:N+4]});

      wasm_v128_store(y, vy${ABC[0:4]});
      $for N in range(4, BATCH_TILE, 4):
        wasm_v128_store(y + ${N}, vy${ABC[N:N+4]});
      y += ${BATCH_TILE};
    }
  // Process 4 elements at a time.
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    v128_t vx = wasm_v128_load(x);
    x += 4;

    $if X86:
      const v128_t vz = wasm_f32x4_mul(vx, vprescale);
    $else:
      const v128_t vz = wasm_f32x4_max(wasm_f32x4_mul(vx, vprescale), vsat_cutoff);

    v128_t vn = wasm_f32x4_add(wasm_f32x4_mul(vz, vlog2e), vmagic_bias);
    v128_t vs = wasm_i32x4_shl(vn, 23);
    vn = wasm_f32x4_sub(vn, vmagic_bias);

    v128_t vt = wasm_f32x4_add(wasm_f32x4_mul(vn, vminus_ln2_hi), vz);
    $if X86:
      const v128_t vsatm = wasm_f32x4_le(vz, vsat_cutoff);
    vt = wasm_f32x4_add(wasm_f32x4_mul(vn, vminus_ln2_lo), vt);
    $if X86:
      vs = wasm_v128_andnot(vs, vsatm);
      vt = wasm_v128_andnot(vt, vsatm);

    v128_t vp = wasm_f32x4_add(wasm_f32x4_mul(vc6, vt), vc5);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vc4);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vc3);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vc2);
    vp = wasm_f32x4_mul(vp, vt);

    vt = wasm_f32x4_mul(vt, vs);
    vs = wasm_f32x4_sub(vs, vone);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vt);
    const v128_t ve = wasm_f32x4_mul(wasm_f32x4_add(vp, vs), valpha);

    const v128_t vsignm = wasm_i32x4_shr(vx, 31);
    vx = wasm_f32x4_mul(vx, vbeta);
    const v128_t vy = wasm_v128_bitselect(ve, vx, vsignm);

    wasm_v128_store(y, vy);
    y += 4;
  }
  if XNN_UNLIKELY(n != 0) {
    // Remainder: 1-3 elements. The full-vector load may read past the end of x;
    // the kernel is annotated XNN_OOB_READS to permit this.
    v128_t vx = wasm_v128_load(x);

    $if X86:
      const v128_t vz = wasm_f32x4_mul(vx, vprescale);
    $else:
      const v128_t vz = wasm_f32x4_max(wasm_f32x4_mul(vx, vprescale), vsat_cutoff);

    v128_t vn = wasm_f32x4_add(wasm_f32x4_mul(vz, vlog2e), vmagic_bias);
    v128_t vs = wasm_i32x4_shl(vn, 23);
    vn = wasm_f32x4_sub(vn, vmagic_bias);

    v128_t vt = wasm_f32x4_add(wasm_f32x4_mul(vn, vminus_ln2_hi), vz);
    $if X86:
      const v128_t vsatm = wasm_f32x4_le(vz, vsat_cutoff);
    vt = wasm_f32x4_add(wasm_f32x4_mul(vn, vminus_ln2_lo), vt);
    $if X86:
      vs = wasm_v128_andnot(vs, vsatm);
      vt = wasm_v128_andnot(vt, vsatm);

    v128_t vp = wasm_f32x4_add(wasm_f32x4_mul(vc6, vt), vc5);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vc4);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vc3);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vc2);
    vp = wasm_f32x4_mul(vp, vt);

    vt = wasm_f32x4_mul(vt, vs);
    vs = wasm_f32x4_sub(vs, vone);
    vp = wasm_f32x4_add(wasm_f32x4_mul(vp, vt), vt);
    const v128_t ve = wasm_f32x4_mul(wasm_f32x4_add(vp, vs), valpha);

    const v128_t vsignm = wasm_i32x4_shr(vx, 31);
    vx = wasm_f32x4_mul(vx, vbeta);
    v128_t vy = wasm_v128_bitselect(ve, vx, vsignm);

    if (n & (2 * sizeof(float))) {
      // Store the low 2 lanes, then rotate the high lanes down for the last element.
      *((double*) y) = wasm_f64x2_extract_lane(vy, 0);
      vy = wasm_v32x4_shuffle(vy, vy, 2, 3, 2, 3);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      *y = wasm_f32x4_extract_lane(vy, 0);
    }
  }
}