1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert DATATYPE in ["QS8", "QU8"] 7$assert CHANNEL_TILE >= 1 8$assert CHANNEL_TILE <= 16 9$assert ROW_TILE >= 3 10$assert ROW_SUBTILE >= 3 11$assert ROW_SUBTILE <= ROW_TILE 12$assert REQUANTIZATION == "FP32" 13#include <assert.h> 14$if VARIANT == "LRINTF": 15 #include <math.h> 16 17#include <xnnpack/gavgpool.h> 18#include <xnnpack/math.h> 19 20 21$PARAMS_STRUCT = "fp32_scalar_" + VARIANT.lower() 22$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 23$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" 24$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" 25void xnn_${DATATYPE.lower()}_gavgpool_minmax_fp32_ukernel_${ROW_TILE}p${ROW_SUBTILE}x__scalar_${VARIANT.lower()}_c${CHANNEL_TILE}( 26 size_t rows, 27 size_t channels, 28 const ${XINT8_T}* input, 29 size_t input_stride, 30 const ${XINT8_T}* zero, 31 int32_t* buffer, 32 ${XINT8_T}* output, 33 const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) 34{ 35 assert(rows > ${ROW_TILE}); 36 assert(channels != 0); 37 38 const ${XINT8_T}* i0 = input; 39 $for M in range(1, ROW_TILE): 40 const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride); 41 const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, ${CHANNEL_TILE}) * sizeof(${XINT8_T}); 42 43 const int32_t vinit_bias = params->${PARAMS_STRUCT}.init_bias; 44 int32_t* b = buffer; 45 $if CHANNEL_TILE == 1: 46 size_t c = channels; 47 do { 48 int32_t vacc = vinit_bias; 49 $for M in range(2): 50 const int32_t vi${M} = (int32_t) *i${M}++; 51 52 $for M in range(2, ROW_TILE): 53 vacc += vi${M-2}; 54 const int32_t vi${M} = (int32_t) *i${M}++; 55 56 $for M in range(ROW_TILE - 2, ROW_TILE): 57 vacc += vi${M}; 58 59 *b++ = vacc; 60 } while (--c != 0); 61 $else: 62 for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) { 63 $for C in range(CHANNEL_TILE): 64 const int32_t vi0x${C} = (int32_t) i0[${C}]; 65 i0 += ${CHANNEL_TILE}; 66 67 $for C in range(CHANNEL_TILE): 68 int32_t vacc${C} = vi0x${C} + vinit_bias; 69 const int32_t vi1x${C} = (int32_t) i1[${C}]; 70 i1 += ${CHANNEL_TILE}; 71 72 $for M in range(2, ROW_TILE): 73 $for C in range(CHANNEL_TILE): 74 vacc${C} += vi${M-1}x${C}; 75 const int32_t vi${M}x${C} = (int32_t) i${M}[${C}]; 76 i${M} += ${CHANNEL_TILE}; 77 78 $for C in range(CHANNEL_TILE): 79 vacc${C} += vi${ROW_TILE-1}x${C}; 80 81 $for C in range(CHANNEL_TILE): 82 b[${C}] = vacc${C}; 83 b += ${CHANNEL_TILE}; 84 } 85 86 for (rows -= ${ROW_TILE}; rows > ${ROW_SUBTILE}; rows -= ${ROW_SUBTILE}) { 87 $for M in range(ROW_SUBTILE): 88 i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment); 89 90 int32_t* b = buffer; 91 $if CHANNEL_TILE == 1: 92 size_t c = channels; 93 do { 94 int32_t vacc = *b; 95 $for M in range(2): 96 const int32_t vi${M} = (int32_t) *i${M}++; 97 98 $for M in range(2, ROW_SUBTILE): 99 vacc += vi${M-2}; 100 const int32_t vi${M} = (int32_t) *i${M}++; 101 102 $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE): 103 vacc += vi${M}; 104 105 *b++ = vacc; 106 } while (--c != 0); 107 $else: 108 for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) { 109 $for C in range(CHANNEL_TILE): 110 int32_t vacc${C} = b[${C}]; 111 const int32_t vi0x${C} = (int32_t) i0[${C}]; 112 i0 += ${CHANNEL_TILE}; 113 114 $for M in range(1, ROW_SUBTILE): 115 $for C in range(CHANNEL_TILE): 116 vacc${C} += vi${M-1}x${C}; 117 const int32_t vi${M}x${C} = (int32_t) i${M}[${C}]; 118 i${M} += ${CHANNEL_TILE}; 119 120 $for C in range(CHANNEL_TILE): 121 vacc${C} += vi${ROW_SUBTILE-1}x${C}; 122 123 $for C in range(CHANNEL_TILE): 124 b[${C}] = vacc${C}; 125 b += ${CHANNEL_TILE}; 126 } 127 } 128 129 i0 = (const ${XINT8_T}*) ((uintptr_t) i${ROW_TILE - ROW_SUBTILE} + input_increment); 130 $for M in range(1, ROW_SUBTILE): 131 i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment); 132 $if M % 2 == 1: 133 if XNN_UNPREDICTABLE(rows < ${M+1}) { 134 i${M} = zero; 135 } 136 $else: 137 if XNN_UNPREDICTABLE(rows <= ${M}) { 138 i${M} = zero; 139 } 140 141 const float vscale = params->${PARAMS_STRUCT}.scale; 142 $if VARIANT == "FMAGIC": 143 const float voutput_min_less_zero_point = params->fp32_scalar_fmagic.output_min_less_zero_point; 144 const float voutput_max_less_zero_point = params->fp32_scalar_fmagic.output_max_less_zero_point; 145 const float vmagic_bias = params->fp32_scalar_fmagic.magic_bias; 146 const int32_t vmagic_bias_less_output_zero_point = params->fp32_scalar_fmagic.magic_bias_less_output_zero_point; 147 $elif VARIANT == "IMAGIC": 148 const float vmagic_bias = params->fp32_scalar_imagic.magic_bias; 149 const int32_t vmagic_min = params->fp32_scalar_imagic.magic_min; 150 const int32_t vmagic_max = params->fp32_scalar_imagic.magic_max; 151 const int32_t vmagic_bias_less_zero_point = params->fp32_scalar_imagic.magic_bias_less_zero_point; 152 $elif VARIANT == "LRINTF": 153 const float voutput_min_less_zero_point = params->fp32_scalar_lrintf.output_min_less_zero_point; 154 const float voutput_max_less_zero_point = params->fp32_scalar_lrintf.output_max_less_zero_point; 155 const int32_t voutput_zero_point = params->fp32_scalar_lrintf.output_zero_point; 156 $if CHANNEL_TILE == 1: 157 do { 158 int32_t vacc = *buffer++; 159 $for M in range(2): 160 const int32_t vi${M} = (int32_t) *i${M}++; 161 162 $for M in range(2, ROW_SUBTILE): 163 vacc += vi${M-2}; 164 const int32_t vi${M} = (int32_t) *i${M}++; 165 166 $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE): 167 vacc += vi${M}; 168 169 float vfpacc = (float) vacc * vscale; 170 $if VARIANT == "FMAGIC": 171 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 172 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 173 vfpacc += vmagic_bias; 174 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 175 $elif VARIANT == "IMAGIC": 176 vfpacc += vmagic_bias; 177 int32_t vout = (int32_t) float_as_uint32(vfpacc); 178 vout = math_max_s32(vout, vmagic_min); 179 vout = math_min_s32(vout, vmagic_max); 180 vout -= vmagic_bias_less_zero_point; 181 $elif VARIANT == "LRINTF": 182 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 183 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 184 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 185 int32_t vout = vrndacc + voutput_zero_point; 186 187 *output++ = (${XINT8_T}) vout; 188 } while (--channels != 0); 189 $else: 190 for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) { 191 $for C in range(CHANNEL_TILE): 192 int32_t vacc${C} = buffer[${C}]; 193 const int32_t vi0x${C} = (int32_t) i0[${C}]; 194 buffer += ${CHANNEL_TILE}; 195 i0 += ${CHANNEL_TILE}; 196 197 $for M in range(1, ROW_SUBTILE): 198 $for C in range(CHANNEL_TILE): 199 vacc${C} += vi${M-1}x${C}; 200 const int32_t vi${M}x${C} = (int32_t) i${M}[${C}]; 201 i${M} += ${CHANNEL_TILE}; 202 203 $for C in range(CHANNEL_TILE): 204 vacc${C} += vi${ROW_SUBTILE-1}x${C}; 205 206 $for C in range(CHANNEL_TILE): 207 float vfpacc${C} = (float) vacc${C} * vscale; 208 209 $if VARIANT == "FMAGIC": 210 $for C in range(CHANNEL_TILE): 211 vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point); 212 213 $for C in range(CHANNEL_TILE): 214 vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point); 215 216 $for C in range(CHANNEL_TILE): 217 vfpacc${C} += vmagic_bias; 218 219 $for C in range(CHANNEL_TILE): 220 int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}) - vmagic_bias_less_output_zero_point; 221 $elif VARIANT == "IMAGIC": 222 $for C in range(CHANNEL_TILE): 223 vfpacc${C} += vmagic_bias; 224 225 $for C in range(CHANNEL_TILE): 226 int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}); 227 228 $for C in range(CHANNEL_TILE): 229 vout${C} = math_max_s32(vout${C}, vmagic_min); 230 231 $for C in range(CHANNEL_TILE): 232 vout${C} = math_min_s32(vout${C}, vmagic_max); 233 234 $for C in range(CHANNEL_TILE): 235 vout${C} -= vmagic_bias_less_zero_point; 236 $elif VARIANT == "LRINTF": 237 $for C in range(CHANNEL_TILE): 238 vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point); 239 240 $for C in range(CHANNEL_TILE): 241 vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point); 242 243 $for C in range(CHANNEL_TILE): 244 const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C}); 245 246 $for C in range(CHANNEL_TILE): 247 int32_t vout${C} = vrndacc${C} + voutput_zero_point; 248 249 $for C in range(CHANNEL_TILE): 250 output[${C}] = (${XINT8_T}) vout${C}; 251 output += ${CHANNEL_TILE}; 252 } 253 if XNN_UNLIKELY(channels != 0) { 254 $if CHANNEL_TILE == 2: 255 int32_t vacc = *buffer; 256 $for M in range(2): 257 const int32_t vi${M} = (int32_t) *i${M}; 258 259 $for M in range(2, ROW_SUBTILE): 260 vacc += vi${M-2}; 261 const int32_t vi${M} = (int32_t) *i${M}; 262 263 $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE): 264 vacc += vi${M}; 265 266 float vfpacc = (float) vacc * vscale; 267 $if VARIANT == "FMAGIC": 268 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 269 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 270 vfpacc += vmagic_bias; 271 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 272 $elif VARIANT == "IMAGIC": 273 vfpacc += vmagic_bias; 274 int32_t vout = (int32_t) float_as_uint32(vfpacc); 275 vout = math_max_s32(vout, vmagic_min); 276 vout = math_min_s32(vout, vmagic_max); 277 vout -= vmagic_bias_less_zero_point; 278 $elif VARIANT == "LRINTF": 279 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 280 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 281 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 282 int32_t vout = vrndacc + voutput_zero_point; 283 284 *output = (${XINT8_T}) vout; 285 $else: 286 do { 287 int32_t vacc = *buffer++; 288 $for M in range(2): 289 const int32_t vi${M} = (int32_t) *i${M}++; 290 291 $for M in range(2, ROW_SUBTILE): 292 vacc += vi${M-2}; 293 const int32_t vi${M} = (int32_t) *i${M}++; 294 295 $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE): 296 vacc += vi${M}; 297 298 float vfpacc = (float) vacc * vscale; 299 $if VARIANT == "FMAGIC": 300 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 301 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 302 vfpacc += vmagic_bias; 303 int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point; 304 $elif VARIANT == "IMAGIC": 305 vfpacc += vmagic_bias; 306 int32_t vout = (int32_t) float_as_uint32(vfpacc); 307 vout = math_max_s32(vout, vmagic_min); 308 vout = math_min_s32(vout, vmagic_max); 309 vout -= vmagic_bias_less_zero_point; 310 $elif VARIANT == "LRINTF": 311 vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point); 312 vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point); 313 const int32_t vrndacc = (int32_t) lrintf(vfpacc); 314 int32_t vout = vrndacc + voutput_zero_point; 315 316 *output++ = (${XINT8_T}) vout; 317 } while (--channels != 0); 318 } 319} 320