// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ROW_TILE >= 1
$assert ACCUMULATORS >= 1
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


$ARCH_SUFFIX = "_x86" if X86 else "_arm"

void xnn_f32_dwconv2d_chw_ukernel_3x3p1__wasmsimd${ARCH_SUFFIX}_splat_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top == 1);

  const v128_t vmask = wasm_v128_load(params->scalar.mask);
  const v128_t vmax = wasm_v128_load32_splat(&params->scalar.max);
  const v128_t vmin = wasm_v128_load32_splat(&params->scalar.min);

  const v128_t vw0123 = wasm_v128_load(weights);
  const v128_t vw4567 = wasm_v128_load(weights + 4);
  const v128_t vw89 = wasm_v128_load64_splat(weights + 8);

  const size_t input_decrement = round_up_po2(input_width, 4 * sizeof(float));

  const float* i0 = zero;
  const float* i1 = input;
  $for M in range(2, 2 + ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

  float* o0 = output;
  $for M in range(1, ROW_TILE):
    float* o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

  size_t output_height = input_height;
  do {
    $for M in range(2, 2 + ROW_TILE):
      if XNN_UNPREDICTABLE(output_height < ${M}) {
        i${M} = zero;
        $if M <= ROW_TILE:
          o${M-1} = o${M-2};
      }

    $for M in range(2 + ROW_TILE):
      v128_t vi${M}x0123 = wasm_f32x4_const_splat(0.0f);

    $for M in range(2 + ROW_TILE):
      v128_t vi${M}x4567 = wasm_v128_load(i${M}); i${M} += 4;

    size_t w = input_width;
    for (; w > 4 * sizeof(float); w -= 4 * sizeof(float)) {
      $for M in range(ROW_TILE):
        v128_t vo${M}p0 = wasm_v32x4_shuffle(vw0123, vw0123, 0, 0, 0, 0);

      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x89AB = wasm_v128_load(i${M}); i${M} += 4;

      $for M in range(ROW_TILE):
        vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M}x4567, wasm_v32x4_shuffle(vw0123, vw0123, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 2:
          v128_t vo${M}p1 = wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 3:
          v128_t vo${M}p2 = wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0)));

      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x3456 = wasm_v32x4_shuffle(vi${M}x0123, vi${M}x4567, 3, 4, 5, 6);

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 4:
          v128_t vo${M}p3 = wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1));
        $else:
          vo${M}p${3 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${3 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1)));

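      // Multiply-accumulate the middle and bottom input rows of the left-shifted (x3456) window.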
      $for M in range(ROW_TILE):
        vo${M}p${4 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${4 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 0, 0, 0, 0)));

      $for M in range(ROW_TILE):
        vo${M}p${5 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${5 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 3, 3, 3, 3)));

      $for M in range(2 + ROW_TILE):
        vi${M}x0123 = vi${M}x4567;

      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x5678 = wasm_v32x4_shuffle(vi${M}x4567, vi${M}x89AB, 1, 2, 3, 4);

      $for M in range(ROW_TILE):
        vo${M}p${6 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${6 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x5678, wasm_v32x4_shuffle(vw0123, vw0123, 3, 3, 3, 3)));

      $for M in range(ROW_TILE):
        vo${M}p${7 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${7 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x5678, wasm_v32x4_shuffle(vw4567, vw4567, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        vo${M}p${8 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${8 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x5678, wasm_v32x4_shuffle(vw89, vw89, 1, 1, 1, 1)));

      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = vi${M}x89AB;

      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $if X86:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_pmax(vmin, vo${M}p0);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_pmin(vmax, vo${M});
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

      $for M in reversed(range(ROW_TILE)):
        wasm_v128_store(o${M}, vo${M}); o${M} += 4;
    }
    // Always process the last block of 1..4 pixels.
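    // vi*x4567 already holds the final (possibly partial) block; vmask zeroes the lanes past the end of the row.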
    assert(w >= 1 * sizeof(float));
    assert(w <= 4 * sizeof(float));
    {
      $for M in range(ROW_TILE):
        v128_t vo${M}p0 = wasm_v32x4_shuffle(vw0123, vw0123, 0, 0, 0, 0);

      $for M in range(2 + ROW_TILE):
        vi${M}x4567 = wasm_v128_and(vmask, vi${M}x4567);

      $for M in range(ROW_TILE):
        vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M}x4567, wasm_v32x4_shuffle(vw0123, vw0123, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 2:
          v128_t vo${M}p1 = wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+1}x4567, wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1)));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 3:
          v128_t vo${M}p2 = wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0));
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${M+2}x4567, wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0)));

      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x3456 = wasm_v32x4_shuffle(vi${M}x0123, vi${M}x4567, 3, 4, 5, 6);

      $for M in range(ROW_TILE):
        $if ACCUMULATORS >= 4:
          v128_t vo${M}p3 = wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1));
        $else:
          vo${M}p${3 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${3 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x3456, wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1)));

      $for M in range(ROW_TILE):
        vo${M}p${4 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${4 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 0, 0, 0, 0)));

      $for M in range(ROW_TILE):
        vo${M}p${5 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${5 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x3456, wasm_v32x4_shuffle(vw4567, vw4567, 3, 3, 3, 3)));

      const v128_t vzero = wasm_f32x4_const_splat(0.0f);
      $for M in range(2 + ROW_TILE):
        const v128_t vi${M}x5678 = wasm_v32x4_shuffle(vi${M}x4567, vzero, 1, 2, 3, 4);

      $for M in range(ROW_TILE):
        vo${M}p${6 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${6 % ACCUMULATORS}, wasm_f32x4_mul(vi${M}x5678, wasm_v32x4_shuffle(vw0123, vw0123, 3, 3, 3, 3)));

      $for M in range(ROW_TILE):
        vo${M}p${7 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${7 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+1}x5678, wasm_v32x4_shuffle(vw4567, vw4567, 2, 2, 2, 2)));

      $for M in range(ROW_TILE):
        vo${M}p${8 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${8 % ACCUMULATORS}, wasm_f32x4_mul(vi${M+2}x5678, wasm_v32x4_shuffle(vw89, vw89, 1, 1, 1, 1)));

      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $if X86:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_pmax(vmin, vo${M}p0);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_pmin(vmax, vo${M});
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

      if XNN_LIKELY(w == 4 * sizeof(float)) {
        $for M in reversed(range(ROW_TILE)):
          wasm_v128_store(o${M}, vo${M}); o${M} += 4;
      } else {
        if (w & (2 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *((double*) o${M}) = wasm_f64x2_extract_lane(vo${M}, 0); o${M} += 2;

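          // Rotate the upper two lanes down so the single-element store below reads lane 0.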
          $for M in range(ROW_TILE):
            vo${M} = wasm_v32x4_shuffle(vo${M}, vo${M}, 2, 3, 0, 1);
        }
        if (w & (1 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *o${M} = wasm_f32x4_extract_lane(vo${M}, 0); o${M} += 1;
        }
      }
    }

    i0 = (const float*) ((uintptr_t) i${ROW_TILE} - input_decrement);
    i1 = (const float*) ((uintptr_t) i${ROW_TILE+1} - input_decrement);
    $for M in range(2, 2 + ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

    $if ROW_TILE > 1:
      o0 = o${ROW_TILE - 1};
      $for M in range(1, ROW_TILE):
        o${M} = (float*) ((uintptr_t) o${M-1} + input_width);

    $if ROW_TILE > 1:
      output_height = doz(output_height, ${ROW_TILE});
  } while (${"--" if ROW_TILE == 1 else ""}output_height != 0);
}