1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert DATATYPE in ["QS8", "QU8"] 7$assert CHANNEL_TILE >= 1 8$assert ROW_TILE >= 3 9$assert REQUANTIZATION == "FP32" 10#include <assert.h> 11$if VARIANT == "LRINTF": 12 #include <math.h> 13 14#include <xnnpack/gavgpool.h> 15#include <xnnpack/math.h> 16 17 18$PARAMS_STRUCT = "fp32_scalar_" + VARIANT.lower() 19$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 20$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" 21$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" 22void xnn_${DATATYPE.lower()}_gavgpool_minmax_fp32_ukernel_${ROW_TILE}x__scalar_${VARIANT.lower()}_c${CHANNEL_TILE}( 23 size_t rows, 24 size_t channels, 25 const ${XINT8_T}* input, 26 size_t input_stride, 27 const ${XINT8_T}* zero, 28 ${XINT8_T}* output, 29 const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) 30{ 31 assert(rows != 0); 32 assert(rows <= ${ROW_TILE}); 33 assert(channels != 0); 34 35 const ${XINT8_T}* i0 = input; 36 $for M in range(1, ROW_TILE): 37 const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride); 38 $if M % 2 == 1: 39 if XNN_UNPREDICTABLE(rows < ${M+1}) { 40 i${M} = zero; 41 } 42 $else: 43 if XNN_UNPREDICTABLE(rows <= ${M}) { 44 i${M} = zero; 45 } 46 47 const int32_t vinit_bias = params->${PARAMS_STRUCT}.init_bias; 48 const float vscale = params->${PARAMS_STRUCT}.scale; 49 $if VARIANT == "FMAGIC": 50 const float voutput_min_less_zero_point = params->fp32_scalar_fmagic.output_min_less_zero_point; 51 const float voutput_max_less_zero_point = params->fp32_scalar_fmagic.output_max_less_zero_point; 52 const float vmagic_bias = params->fp32_scalar_fmagic.magic_bias; 53 const int32_t vmagic_bias_less_output_zero_point = params->fp32_scalar_fmagic.magic_bias_less_output_zero_point; 54 $elif VARIANT == "IMAGIC": 55 const float vmagic_bias = params->fp32_scalar_imagic.magic_bias; 56 const int32_t vmagic_min = params->fp32_scalar_imagic.magic_min; 57 const int32_t vmagic_max = params->fp32_scalar_imagic.magic_max; 58 const int32_t vmagic_bias_less_zero_point = params->fp32_scalar_imagic.magic_bias_less_zero_point; 59 $elif VARIANT == "LRINTF": 60 const float voutput_min_less_zero_point = params->fp32_scalar_lrintf.output_min_less_zero_point; 61 const float voutput_max_less_zero_point = params->fp32_scalar_lrintf.output_max_less_zero_point; 62 const int32_t voutput_zero_point = params->fp32_scalar_lrintf.output_zero_point; 63 $if CHANNEL_TILE > 1: 64 for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) { 65 $for C in range(CHANNEL_TILE): 66 const int32_t vi0x${C} = (int32_t) i0[${C}]; 67 i0 += ${CHANNEL_TILE}; 68 69 $for C in range(CHANNEL_TILE): 70 int32_t vacc${C} = vi0x${C} + vinit_bias; 71 const int32_t vi1x${C} = (int32_t) i1[${C}]; 72 i1 += ${CHANNEL_TILE}; 73 74 $for M in range(2, ROW_TILE): 75 $for C in range(CHANNEL_TILE): 76 vacc${C} += vi${M-1}x${C}; 77 const int32_t vi${M}x${C} = (int32_t) i${M}[${C}]; 78 i${M} += ${CHANNEL_TILE}; 79 80 $for C in range(CHANNEL_TILE): 81 vacc${C} += vi${ROW_TILE-1}x${C}; 82 83 $for C in range(CHANNEL_TILE): 84 float vfpacc${C} = (float) vacc${C} * vscale; 85 86 $if VARIANT == "FMAGIC": 87 $for C in range(CHANNEL_TILE): 88 vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point); 89 90 $for C in range(CHANNEL_TILE): 91 vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point); 92 93 $for C in range(CHANNEL_TILE): 94 vfpacc${C} += vmagic_bias; 95 96 $for C in range(CHANNEL_TILE): 97 int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}) - vmagic_bias_less_output_zero_point; 98 $elif VARIANT == "IMAGIC": 99 $for C in range(CHANNEL_TILE): 100 vfpacc${C} += vmagic_bias; 101 102 $for C in range(CHANNEL_TILE): 103 int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}); 104 105 $for C in range(CHANNEL_TILE): 106 vout${C} = math_max_s32(vout${C}, vmagic_min); 107 108 $for C in range(CHANNEL_TILE): 109 vout${C} = math_min_s32(vout${C}, vmagic_max); 110 111 $for C in range(CHANNEL_TILE): 112 vout${C} -= vmagic_bias_less_zero_point; 113 $elif VARIANT == "LRINTF": 114 $for C in range(CHANNEL_TILE): 115 vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point); 116 117 $for C in range(CHANNEL_TILE): 118 vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point); 119 120 $for C in range(CHANNEL_TILE): 121 const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C}); 122 123 $for C in range(CHANNEL_TILE): 124 int32_t vout${C} = vrndacc${C} + voutput_zero_point; 125 126 $for C in range(CHANNEL_TILE): 127 output[${C}] = (${XINT8_T}) vout${C}; 128 output += ${CHANNEL_TILE}; 129 } 130 $if CHANNEL_TILE == 1: 131 do { 132 int32_t vacc = vinit_bias; 133 $for M in range(2): 134 const int32_t vi${M} = (int32_t) *i${M}++; 135 136 $for M in range(2, ROW_TILE): 137 vacc += vi${M-2}; 138 const int32_t vi${M} = (int32_t) *i${M}++; 139 140 $for M in range(ROW_TILE - 2, ROW_TILE): 141 vacc += vi${M}; 142 143 float vfpacc = (float) vacc * vscale; 144 $if VARIANT == "FMAGIC": 145 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 146 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 147 vfpacc += vmagic_bias; 148 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 149 $elif VARIANT == "IMAGIC": 150 vfpacc += vmagic_bias; 151 int32_t vout = (int32_t) float_as_uint32(vfpacc); 152 vout = math_max_s32(vout, vmagic_min); 153 vout = math_min_s32(vout, vmagic_max); 154 vout -= vmagic_bias_less_zero_point; 155 $elif VARIANT == "LRINTF": 156 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 157 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 158 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 159 int32_t vout = vrndacc + voutput_zero_point; 160 161 *output++ = (${XINT8_T}) vout; 162 } while (--channels != 0); 163 $else: 164 if XNN_UNLIKELY(channels != 0) { 165 $if CHANNEL_TILE == 2: 166 int32_t vacc = vinit_bias; 167 $for M in range(2): 168 const int32_t vi${M} = (int32_t) *i${M}; 169 170 $for M in range(2, ROW_TILE): 171 vacc += vi${M-2}; 172 const int32_t vi${M} = (int32_t) *i${M}; 173 174 $for M in range(ROW_TILE - 2, ROW_TILE): 175 vacc += vi${M}; 176 177 float vfpacc = (float) vacc * vscale; 178 $if VARIANT == "FMAGIC": 179 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 180 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 181 vfpacc += vmagic_bias; 182 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 183 $elif VARIANT == "IMAGIC": 184 vfpacc += vmagic_bias; 185 int32_t vout = (int32_t) float_as_uint32(vfpacc); 186 vout = math_max_s32(vout, vmagic_min); 187 vout = math_min_s32(vout, vmagic_max); 188 vout -= vmagic_bias_less_zero_point; 189 $elif VARIANT == "LRINTF": 190 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 191 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 192 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 193 int32_t vout = vrndacc + voutput_zero_point; 194 195 *output = (${XINT8_T}) vout; 196 $else: 197 do { 198 int32_t vacc = vinit_bias; 199 $for M in range(2): 200 const int32_t vi${M} = (int32_t) *i${M}++; 201 202 $for M in range(2, ROW_TILE): 203 vacc += vi${M-2}; 204 const int32_t vi${M} = (int32_t) *i${M}++; 205 206 $for M in range(ROW_TILE - 2, ROW_TILE): 207 vacc += vi${M}; 208 209 float vfpacc = (float) vacc * vscale; 210 $if VARIANT == "FMAGIC": 211 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 212 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 213 vfpacc += vmagic_bias; 214 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 215 $elif VARIANT == "IMAGIC": 216 vfpacc += vmagic_bias; 217 int32_t vout = (int32_t) float_as_uint32(vfpacc); 218 vout = math_max_s32(vout, vmagic_min); 219 vout = math_min_s32(vout, vmagic_max); 220 vout -= vmagic_bias_less_zero_point; 221 $elif VARIANT == "LRINTF": 222 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 223 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 224 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 225 int32_t vout = vrndacc + voutput_zero_point; 226 227 *output++ = (${XINT8_T}) vout; 228 } while (--channels != 0); 229 } 230} 231