1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker// 3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker$assert CHANNEL_TILE >= 1 7*4bdc9457SAndroid Build Coastguard Worker$assert ROW_TILE >= 1 8*4bdc9457SAndroid Build Coastguard Worker$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9*4bdc9457SAndroid Build Coastguard Worker#include <assert.h> 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/math.h> 12*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/vmulcaddc.h> 13*4bdc9457SAndroid Build Coastguard Worker 14*4bdc9457SAndroid Build Coastguard Worker 15*4bdc9457SAndroid Build Coastguard Worker$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" 16*4bdc9457SAndroid Build Coastguard Worker$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" 17*4bdc9457SAndroid Build Coastguard Workervoid xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__${"wasm" if WASM else "scalar"}_${ROW_TILE}x( 18*4bdc9457SAndroid Build Coastguard Worker size_t rows, 19*4bdc9457SAndroid Build Coastguard Worker size_t channels, 20*4bdc9457SAndroid Build Coastguard Worker const float*restrict input, 21*4bdc9457SAndroid Build Coastguard Worker size_t input_stride, 22*4bdc9457SAndroid Build Coastguard Worker const float*restrict weights, 23*4bdc9457SAndroid Build Coastguard Worker float*restrict output, 24*4bdc9457SAndroid Build Coastguard Worker size_t output_stride, 25*4bdc9457SAndroid Build Coastguard Worker const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) 26*4bdc9457SAndroid Build Coastguard Worker{ 27*4bdc9457SAndroid Build Coastguard Worker assert(rows != 0); 28*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 29*4bdc9457SAndroid Build Coastguard Worker assert(channels % sizeof(float) == 0); 30*4bdc9457SAndroid Build Coastguard Worker 31*4bdc9457SAndroid Build Coastguard Worker const size_t input_increment = input_stride * ${ROW_TILE} - channels; 32*4bdc9457SAndroid Build Coastguard Worker const size_t output_increment = output_stride * ${ROW_TILE} - channels; 33*4bdc9457SAndroid Build Coastguard Worker 34*4bdc9457SAndroid Build Coastguard Worker const float* i0 = input; 35*4bdc9457SAndroid Build Coastguard Worker float* o0 = output; 36*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, ROW_TILE): 37*4bdc9457SAndroid Build Coastguard Worker const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride); 38*4bdc9457SAndroid Build Coastguard Worker float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride); 39*4bdc9457SAndroid Build Coastguard Worker 40*4bdc9457SAndroid Build Coastguard Worker const float vmin = params->scalar.min; 41*4bdc9457SAndroid Build Coastguard Worker const float vmax = params->scalar.max; 42*4bdc9457SAndroid Build Coastguard Worker do { 43*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, ROW_TILE): 44*4bdc9457SAndroid Build Coastguard Worker $if M % 2 == 0: 45*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(rows <= ${M}) { 46*4bdc9457SAndroid Build Coastguard Worker i${M} = i${M-1}; 47*4bdc9457SAndroid Build Coastguard Worker o${M} = o${M-1}; 48*4bdc9457SAndroid Build Coastguard Worker } 49*4bdc9457SAndroid Build Coastguard Worker $else: 50*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(rows < ${M+1}) { 51*4bdc9457SAndroid Build Coastguard Worker i${M} = i${M-1}; 52*4bdc9457SAndroid Build Coastguard Worker o${M} = o${M-1}; 53*4bdc9457SAndroid Build Coastguard Worker } 54*4bdc9457SAndroid Build Coastguard Worker 55*4bdc9457SAndroid Build Coastguard Worker const float* w = weights; 56*4bdc9457SAndroid Build Coastguard Worker size_t c = channels; 57*4bdc9457SAndroid Build Coastguard Worker $if CHANNEL_TILE > 1: 58*4bdc9457SAndroid Build Coastguard Worker for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) { 59*4bdc9457SAndroid Build Coastguard Worker $for C in range(CHANNEL_TILE): 60*4bdc9457SAndroid Build Coastguard Worker const float vscale${ABC[C]} = w[${C}]; 61*4bdc9457SAndroid Build Coastguard Worker 62*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 63*4bdc9457SAndroid Build Coastguard Worker $for C in range(CHANNEL_TILE): 64*4bdc9457SAndroid Build Coastguard Worker float vacc${M}x${ABC[C]} = i${M}[${C}]; 65*4bdc9457SAndroid Build Coastguard Worker i${M} += ${CHANNEL_TILE}; 66*4bdc9457SAndroid Build Coastguard Worker 67*4bdc9457SAndroid Build Coastguard Worker $for C in range(CHANNEL_TILE): 68*4bdc9457SAndroid Build Coastguard Worker const float vbias${ABC[C]} = w[${C + CHANNEL_TILE}]; 69*4bdc9457SAndroid Build Coastguard Worker 70*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 71*4bdc9457SAndroid Build Coastguard Worker $for C in range(CHANNEL_TILE): 72*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[C]} = vacc${M}x${ABC[C]} * vscale${ABC[C]} + vbias${ABC[C]}; 73*4bdc9457SAndroid Build Coastguard Worker 74*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 75*4bdc9457SAndroid Build Coastguard Worker $for C in range(CHANNEL_TILE): 76*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[C]} = ${MAX_F32}(vacc${M}x${ABC[C]}, vmin); 77*4bdc9457SAndroid Build Coastguard Worker 78*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 79*4bdc9457SAndroid Build Coastguard Worker $for C in range(CHANNEL_TILE): 80*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[C]} = ${MIN_F32}(vacc${M}x${ABC[C]}, vmax); 81*4bdc9457SAndroid Build Coastguard Worker 82*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 83*4bdc9457SAndroid Build Coastguard Worker $for C in range(CHANNEL_TILE): 84*4bdc9457SAndroid Build Coastguard Worker o${M}[${C}] = vacc${M}x${ABC[C]}; 85*4bdc9457SAndroid Build Coastguard Worker o${M} += ${CHANNEL_TILE}; 86*4bdc9457SAndroid Build Coastguard Worker 87*4bdc9457SAndroid Build Coastguard Worker w += ${CHANNEL_TILE * 2}; 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(c != 0) { 90*4bdc9457SAndroid Build Coastguard Worker do { 91*4bdc9457SAndroid Build Coastguard Worker const float vscale = *w++; 92*4bdc9457SAndroid Build Coastguard Worker 93*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 94*4bdc9457SAndroid Build Coastguard Worker float vacc${M} = *i${M}++; 95*4bdc9457SAndroid Build Coastguard Worker 96*4bdc9457SAndroid Build Coastguard Worker const float vbias = w[${CHANNEL_TILE - 1}]; 97*4bdc9457SAndroid Build Coastguard Worker 98*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 99*4bdc9457SAndroid Build Coastguard Worker vacc${M} = vacc${M} * vscale + vbias; 100*4bdc9457SAndroid Build Coastguard Worker 101*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 102*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${MAX_F32}(vacc${M}, vmin); 103*4bdc9457SAndroid Build Coastguard Worker 104*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 105*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${MIN_F32}(vacc${M}, vmax); 106*4bdc9457SAndroid Build Coastguard Worker 107*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 108*4bdc9457SAndroid Build Coastguard Worker *o${M}++ = vacc${M}; 109*4bdc9457SAndroid Build Coastguard Worker 110*4bdc9457SAndroid Build Coastguard Worker c -= sizeof(float); 111*4bdc9457SAndroid Build Coastguard Worker } while (c != 0); 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker $else: 114*4bdc9457SAndroid Build Coastguard Worker do { 115*4bdc9457SAndroid Build Coastguard Worker const float vscale = w[0]; 116*4bdc9457SAndroid Build Coastguard Worker 117*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 118*4bdc9457SAndroid Build Coastguard Worker float vacc${M} = *i${M}++; 119*4bdc9457SAndroid Build Coastguard Worker 120*4bdc9457SAndroid Build Coastguard Worker const float vbias = w[1]; 121*4bdc9457SAndroid Build Coastguard Worker 122*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 123*4bdc9457SAndroid Build Coastguard Worker vacc${M} = vacc${M} * vscale + vbias; 124*4bdc9457SAndroid Build Coastguard Worker 125*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 126*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${MAX_F32}(vacc${M}, vmin); 127*4bdc9457SAndroid Build Coastguard Worker 128*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 129*4bdc9457SAndroid Build Coastguard Worker vacc${M} = ${MIN_F32}(vacc${M}, vmax); 130*4bdc9457SAndroid Build Coastguard Worker 131*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 132*4bdc9457SAndroid Build Coastguard Worker *o${M}++ = vacc${M}; 133*4bdc9457SAndroid Build Coastguard Worker 134*4bdc9457SAndroid Build Coastguard Worker w += 2; 135*4bdc9457SAndroid Build Coastguard Worker c -= sizeof(float); 136*4bdc9457SAndroid Build Coastguard Worker } while (c != 0); 137*4bdc9457SAndroid Build Coastguard Worker $for M in range(ROW_TILE): 138*4bdc9457SAndroid Build Coastguard Worker i${M} = (const float*) ((uintptr_t) i${M} + input_increment); 139*4bdc9457SAndroid Build Coastguard Worker o${M} = (float*) ((uintptr_t) o${M} + output_increment); 140*4bdc9457SAndroid Build Coastguard Worker rows = doz(rows, ${ROW_TILE}); 141*4bdc9457SAndroid Build Coastguard Worker } while (rows != 0); 142*4bdc9457SAndroid Build Coastguard Worker} 143