1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2020 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker// 3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker$assert CHANNEL_TILE % 4 == 0 7*4bdc9457SAndroid Build Coastguard Worker$assert CHANNEL_TILE >= 4 8*4bdc9457SAndroid Build Coastguard Worker$assert ROW_TILE >= 1 9*4bdc9457SAndroid Build Coastguard Worker$assert ARCH in ["ARM", "X86", "RELAXED"] 10*4bdc9457SAndroid Build Coastguard Worker$assert not FMA or ARCH == "RELAXED" 11*4bdc9457SAndroid Build Coastguard Worker$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 12*4bdc9457SAndroid Build Coastguard Worker#include <assert.h> 13*4bdc9457SAndroid Build Coastguard Worker 14*4bdc9457SAndroid Build Coastguard Worker#include <wasm_simd128.h> 15*4bdc9457SAndroid Build Coastguard Worker 16*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/math.h> 17*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/vmulcaddc.h> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker$WASM_F32X4_MIN={"ARM": "wasm_f32x4_min", "X86": "wasm_f32x4_pmin", "RELAXED": "__builtin_wasm_relaxed_min_f32x4"}[ARCH] 21*4bdc9457SAndroid Build Coastguard Worker$WASM_F32X4_MAX={"ARM": "wasm_f32x4_max", "X86": "wasm_f32x4_pmax", "RELAXED": "__builtin_wasm_relaxed_max_f32x4"}[ARCH] 22*4bdc9457SAndroid Build Coastguard Worker$ISA = "wasmsimd" if ARCH != "RELAXED" else "wasmrelaxedsimd" 23*4bdc9457SAndroid Build Coastguard Worker$ARCH_SUFFIX = "" if ARCH == "RELAXED" and not FMA else "_" + ("fma" if FMA else ARCH.lower()) 24*4bdc9457SAndroid Build Coastguard Workervoid xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__${ISA}${ARCH_SUFFIX}_${ROW_TILE}x( 25*4bdc9457SAndroid Build Coastguard Worker size_t rows, 26*4bdc9457SAndroid Build Coastguard Worker size_t channels, 27*4bdc9457SAndroid Build Coastguard Worker const float*restrict input, 28*4bdc9457SAndroid Build Coastguard Worker size_t input_stride, 29*4bdc9457SAndroid Build Coastguard Worker const float*restrict weights, 30*4bdc9457SAndroid Build Coastguard Worker float*restrict output, 31*4bdc9457SAndroid Build Coastguard Worker size_t output_stride, 32*4bdc9457SAndroid Build Coastguard Worker const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 33*4bdc9457SAndroid Build Coastguard Worker{ 34*4bdc9457SAndroid Build Coastguard Worker assert(rows != 0); 35*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 36*4bdc9457SAndroid Build Coastguard Worker assert(channels % sizeof(float) == 0); 37*4bdc9457SAndroid Build Coastguard Worker 38*4bdc9457SAndroid Build Coastguard Worker const float* i0 = input; 39*4bdc9457SAndroid Build Coastguard Worker float* o0 = output; 40*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, ROW_TILE): 41*4bdc9457SAndroid Build Coastguard Worker const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride); 42*4bdc9457SAndroid Build Coastguard Worker float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride); 43*4bdc9457SAndroid Build Coastguard Worker 44*4bdc9457SAndroid Build Coastguard Worker const size_t input_increment = input_stride * ${ROW_TILE} - channels; 45*4bdc9457SAndroid Build Coastguard Worker const size_t output_increment = output_stride * ${ROW_TILE} - channels; 46*4bdc9457SAndroid Build Coastguard Worker 47*4bdc9457SAndroid Build Coastguard Worker const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min); 48*4bdc9457SAndroid Build Coastguard Worker const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max); 49*4bdc9457SAndroid Build Coastguard Worker do { 50*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, ROW_TILE): 51*4bdc9457SAndroid Build Coastguard Worker $if M % 2 == 0: 52*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(rows <= ${M}) { 53*4bdc9457SAndroid Build Coastguard Worker i${M} = i${M-1}; 54*4bdc9457SAndroid Build Coastguard Worker o${M} = o${M-1}; 55*4bdc9457SAndroid Build Coastguard Worker } 56*4bdc9457SAndroid Build Coastguard Worker $else: 57*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(rows < ${M+1}) { 58*4bdc9457SAndroid Build Coastguard Worker i${M} = i${M-1}; 59*4bdc9457SAndroid Build Coastguard Worker o${M} = o${M-1}; 60*4bdc9457SAndroid Build Coastguard Worker } 61*4bdc9457SAndroid Build Coastguard Worker 62*4bdc9457SAndroid Build Coastguard Worker const float* w = weights; 63*4bdc9457SAndroid Build Coastguard Worker size_t c = channels; 64*4bdc9457SAndroid Build Coastguard Worker for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) { 65*4bdc9457SAndroid Build Coastguard Worker const v128_t vscale${ABC[0:4]} = wasm_v128_load(w); 66*4bdc9457SAndroid Build Coastguard Worker $for C in range(4, CHANNEL_TILE, 4): 67*4bdc9457SAndroid Build Coastguard Worker const v128_t vscale${ABC[C:C+4]} = wasm_v128_load(w + ${C}); 68*4bdc9457SAndroid Build Coastguard Worker 69*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 70*4bdc9457SAndroid Build Coastguard Worker v128_t vacc${M}x${ABC[0:4]} = wasm_v128_load(i${M}); 71*4bdc9457SAndroid Build Coastguard Worker $for C in range(4, CHANNEL_TILE, 4): 72*4bdc9457SAndroid Build Coastguard Worker v128_t vacc${M}x${ABC[C:C+4]} = wasm_v128_load(i${M} + ${C}); 73*4bdc9457SAndroid Build Coastguard Worker i${M} += ${CHANNEL_TILE}; 74*4bdc9457SAndroid Build Coastguard Worker 75*4bdc9457SAndroid Build Coastguard Worker $for C in range(0, CHANNEL_TILE, 4): 76*4bdc9457SAndroid Build Coastguard Worker const v128_t vbias${ABC[C:C+4]} = wasm_v128_load(w + ${C + CHANNEL_TILE}); 77*4bdc9457SAndroid Build Coastguard Worker 78*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 79*4bdc9457SAndroid Build Coastguard Worker $for C in range(0, CHANNEL_TILE, 4): 80*4bdc9457SAndroid Build Coastguard Worker $if FMA: 81*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[C:C+4]} = __builtin_wasm_fma_f32x4(vbias${ABC[C:C+4]}, vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]}); 82*4bdc9457SAndroid Build Coastguard Worker $else: 83*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[C:C+4]} = wasm_f32x4_add(vbias${ABC[C:C+4]}, wasm_f32x4_mul(vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]})); 84*4bdc9457SAndroid Build Coastguard Worker 85*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 86*4bdc9457SAndroid Build Coastguard Worker $for C in range(0, CHANNEL_TILE, 4): 87*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[C:C+4]} = ${WASM_F32X4_MAX}(vmin, vacc${M}x${ABC[C:C+4]}); 88*4bdc9457SAndroid Build Coastguard Worker 89*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 90*4bdc9457SAndroid Build Coastguard Worker $for C in range(0, CHANNEL_TILE, 4): 91*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[C:C+4]} = ${WASM_F32X4_MIN}(vmax, vacc${M}x${ABC[C:C+4]}); 92*4bdc9457SAndroid Build Coastguard Worker 93*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 94*4bdc9457SAndroid Build Coastguard Worker wasm_v128_store(o${M}, vacc${M}x${ABC[0:4]}); 95*4bdc9457SAndroid Build Coastguard Worker $for C in range(4, CHANNEL_TILE, 4): 96*4bdc9457SAndroid Build Coastguard Worker wasm_v128_store(o${M} + ${C}, vacc${M}x${ABC[C:C+4]}); 97*4bdc9457SAndroid Build Coastguard Worker o${M} += ${CHANNEL_TILE}; 98*4bdc9457SAndroid Build Coastguard Worker 99*4bdc9457SAndroid Build Coastguard Worker w += ${CHANNEL_TILE * 2}; 100*4bdc9457SAndroid Build Coastguard Worker } 101*4bdc9457SAndroid Build Coastguard Worker $if CHANNEL_TILE > 4: 102*4bdc9457SAndroid Build Coastguard Worker for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) { 103*4bdc9457SAndroid Build Coastguard Worker const v128_t vscale = wasm_v128_load(w); 104*4bdc9457SAndroid Build Coastguard Worker 105*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 106*4bdc9457SAndroid Build Coastguard Worker v128_t vacc${M} = wasm_v128_load(i${M}); 107*4bdc9457SAndroid Build Coastguard Worker i${M} += 4; 108*4bdc9457SAndroid Build Coastguard Worker 109*4bdc9457SAndroid Build Coastguard Worker const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE}); 110*4bdc9457SAndroid Build Coastguard Worker 111*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 112*4bdc9457SAndroid Build Coastguard Worker $if FMA: 113*4bdc9457SAndroid Build Coastguard Worker vacc${M} = __builtin_wasm_fma_f32x4(vbias, vscale, vacc${M}); 114*4bdc9457SAndroid Build Coastguard Worker $else: 115*4bdc9457SAndroid Build Coastguard Worker vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M})); 116*4bdc9457SAndroid Build Coastguard Worker 117*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 118*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${WASM_F32X4_MAX}(vmin, vacc${M}); 119*4bdc9457SAndroid Build Coastguard Worker 120*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 121*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${WASM_F32X4_MIN}(vmax, vacc${M}); 122*4bdc9457SAndroid Build Coastguard Worker 123*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 124*4bdc9457SAndroid Build Coastguard Worker wasm_v128_store(o${M}, vacc${M}); 125*4bdc9457SAndroid Build Coastguard Worker o${M} += 4; 126*4bdc9457SAndroid Build Coastguard Worker 127*4bdc9457SAndroid Build Coastguard Worker w += 4; 128*4bdc9457SAndroid Build Coastguard Worker } 129*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(c != 0) { 130*4bdc9457SAndroid Build Coastguard Worker const v128_t vscale = wasm_v128_load(w); 131*4bdc9457SAndroid Build Coastguard Worker 132*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 133*4bdc9457SAndroid Build Coastguard Worker v128_t vacc${M} = wasm_v128_load(i${M}); 134*4bdc9457SAndroid Build Coastguard Worker i${M} = (const float*) ((uintptr_t) i${M} + c); 135*4bdc9457SAndroid Build Coastguard Worker 136*4bdc9457SAndroid Build Coastguard Worker const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE}); 137*4bdc9457SAndroid Build Coastguard Worker 138*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 139*4bdc9457SAndroid Build Coastguard Worker $if FMA: 140*4bdc9457SAndroid Build Coastguard Worker vacc${M} = __builtin_wasm_fma_f32x4(vbias, vscale, vacc${M}); 141*4bdc9457SAndroid Build Coastguard Worker $else: 142*4bdc9457SAndroid Build Coastguard Worker vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M})); 143*4bdc9457SAndroid Build Coastguard Worker 144*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 145*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${WASM_F32X4_MAX}(vmin, vacc${M}); 146*4bdc9457SAndroid Build Coastguard Worker 147*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 148*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${WASM_F32X4_MIN}(vmax, vacc${M}); 149*4bdc9457SAndroid Build Coastguard Worker 150*4bdc9457SAndroid Build Coastguard Worker if (c & (2 * sizeof(float))) { 151*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 152*4bdc9457SAndroid Build Coastguard Worker *((double*) o${M}) = wasm_f64x2_extract_lane(vacc${M}, 0); 153*4bdc9457SAndroid Build Coastguard Worker 154*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 155*4bdc9457SAndroid Build Coastguard Worker vacc${M} = wasm_v32x4_shuffle(vacc${M}, vacc${M}, 2, 3, 2, 3); 156*4bdc9457SAndroid Build Coastguard Worker 157*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 158*4bdc9457SAndroid Build Coastguard Worker o${M} += 2; 159*4bdc9457SAndroid Build Coastguard Worker } 160*4bdc9457SAndroid Build Coastguard Worker if (c & (1 * sizeof(float))) { 161*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 162*4bdc9457SAndroid Build Coastguard Worker *o${M}++ = wasm_f32x4_extract_lane(vacc${M}, 0); 163*4bdc9457SAndroid Build Coastguard Worker } 164*4bdc9457SAndroid Build Coastguard Worker } 165*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 166*4bdc9457SAndroid Build Coastguard Worker i${M} = (const float*) ((uintptr_t) i${M} + input_increment); 167*4bdc9457SAndroid Build Coastguard Worker o${M} = (float*) ((uintptr_t) o${M} + output_increment); 168*4bdc9457SAndroid Build Coastguard Worker rows = doz(rows, ${ROW_TILE}); 169*4bdc9457SAndroid Build Coastguard Worker } while (rows != 0); 170*4bdc9457SAndroid Build Coastguard Worker} 171