1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert REQUANTIZATION == "FP32" 7$assert VARIANT in ["FMAGIC", "IMAGIC", "LRINTF"] 8$assert DATATYPE in ["QC8", "QS8", "QU8"] 9#include <assert.h> 10$if VARIANT == "LRINTF": 11 #include <math.h> 12 13#include <xnnpack/math.h> 14#include <xnnpack/gemm.h> 15$if NR % 4 != 0: 16 #include <xnnpack/unaligned.h> 17 18 19$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" + ("_" + VARIANT.lower() if VARIANT else "") 20$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower() 21$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 22$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" 23$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" 24void xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}__${"wasm" if WASM else "scalar"}_${VARIANT.lower()}( 25 size_t mr, 26 size_t nc, 27 size_t kc, 28 size_t ks, 29 const ${XINT8_T}**restrict a, 30 const void*restrict w, 31 ${XINT8_T}*restrict c, 32 size_t cm_stride, 33 size_t cn_stride, 34 size_t a_offset, 35 const ${XINT8_T}* zero, 36 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) 37{ 38 assert(mr != 0); 39 assert(mr <= ${MR}); 40 assert(nc != 0); 41 assert(kc != 0); 42 assert(ks != 0); 43 assert(ks % (${MR} * sizeof(void*)) == 0); 44 assert(a != NULL); 45 assert(w != NULL); 46 assert(c != NULL); 47 48 ${XINT8_T}* c0 = c; 49 $for M in range(1, MR): 50 ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride); 51 $if M % 2 == 0: 52 if XNN_UNPREDICTABLE(mr <= ${M}) { 53 c${M} = c${M-1}; 54 } 55 $elif M + 1 == MR: 56 if XNN_UNPREDICTABLE(mr != ${M+1}) { 57 c${M} = c${M-1}; 58 } 59 $else: 60 if XNN_UNPREDICTABLE(mr < ${M+1}) { 61 c${M} = c${M-1}; 62 } 63 64 $if DATATYPE == "QU8": 65 const int32_t vb_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point; 66 do { 67 $if NR % 4 != 0: 68 $for N in range(NR): 69 int32_t vacc0x${N} = unaligned_indexed_load_s32(w, ${N}); 70 $else: 71 $for N in range(NR): 72 int32_t vacc0x${N} = ((const int32_t*) w)[${N}]; 73 $for M in range(1, MR): 74 $for N in range(NR): 75 int32_t vacc${M}x${N} = vacc0x${N}; 76 w = (const void*) ((const int32_t*) w + ${NR}); 77 78 size_t p = ks; 79 do { 80 $for M in range(MR): 81 const ${XINT8_T}* restrict a${M} = a[${M}]; 82 assert(a${M} != NULL); 83 if XNN_UNPREDICTABLE(a${M} != zero) { 84 a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + a_offset); 85 } 86 a += ${MR}; 87 88 size_t k = kc; 89 do { 90 $for M in range(MR): 91 $if DATATYPE == "QU8": 92 const int32_t va${M} = (int32_t) (uint32_t) *a${M}++; 93 $else: 94 const int32_t va${M} = (int32_t) *a${M}++; 95 96 $for N in range(NR): 97 $if DATATYPE == "QU8": 98 const int32_t vb${N} = (int32_t) (uint32_t) ((const uint8_t*) w)[${N}] - vb_zero_point; 99 $else: 100 const int32_t vb${N} = (int32_t) ((const int8_t*) w)[${N}]; 101 w = (const void*) ((const ${XINT8_T}*) w + ${NR}); 102 103 $for M in range(MR): 104 $for N in range(NR): 105 vacc${M}x${N} += va${M} * vb${N}; 106 107 k -= sizeof(${XINT8_T}); 108 } while (k != 0); 109 p -= ${MR} * sizeof(void*); 110 } while (p != 0); 111 112 $for M in range(MR): 113 $for N in range(NR): 114 float vfpacc${M}x${N} = (float) vacc${M}x${N}; 115 116 $if DATATYPE == "QC8": 117 $if NR % 4 != 0: 118 $for N in range(NR): 119 const float vscale${N} = unaligned_indexed_load_f32(w, ${N}); 120 $for M in range(MR): 121 vfpacc${M}x${N} *= vscale${N}; 122 $else: 123 $for N in range(NR): 124 const float vscale${N} = ((const float*) w)[${N}]; 125 $for M in range(MR): 126 vfpacc${M}x${N} *= vscale${N}; 127 w = (const void*) ((const float*) w + ${NR}); 128 $else: 129 const float vscale = params->${PARAMS_STRUCT}.scale; 130 $for M in range(MR): 131 $for N in range(NR): 132 vfpacc${M}x${N} *= vscale; 133 134 $if VARIANT == "FMAGIC": 135 const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; 136 $for M in range(MR): 137 $for N in range(NR): 138 vfpacc${M}x${N} = ${MAX_F32}(vfpacc${M}x${N}, voutput_min_less_zero_point); 139 140 const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; 141 $for M in range(MR): 142 $for N in range(NR): 143 vfpacc${M}x${N} = ${MIN_F32}(vfpacc${M}x${N}, voutput_max_less_zero_point); 144 145 const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; 146 $for M in range(MR): 147 $for N in range(NR): 148 vfpacc${M}x${N} += vmagic_bias; 149 150 const int32_t vmagic_bias_less_output_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point; 151 $for M in range(MR): 152 $for N in range(NR): 153 int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N}) - vmagic_bias_less_output_zero_point; 154 $elif VARIANT == "IMAGIC": 155 const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; 156 $for M in range(MR): 157 $for N in range(NR): 158 vfpacc${M}x${N} += vmagic_bias; 159 160 $for M in range(MR): 161 $for N in range(NR): 162 int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N}); 163 164 const int32_t vmagic_min = params->${PARAMS_STRUCT}.magic_min; 165 $for M in range(MR): 166 $for N in range(NR): 167 vout${M}x${N} = math_max_s32(vout${M}x${N}, vmagic_min); 168 169 const int32_t vmagic_max = params->${PARAMS_STRUCT}.magic_max; 170 $for M in range(MR): 171 $for N in range(NR): 172 vout${M}x${N} = math_min_s32(vout${M}x${N}, vmagic_max); 173 174 const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point; 175 $for M in range(MR): 176 $for N in range(NR): 177 vout${M}x${N} -= vmagic_bias_less_zero_point; 178 $elif VARIANT == "LRINTF": 179 const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; 180 $for M in range(MR): 181 $for N in range(NR): 182 vfpacc${M}x${N} = ${MAX_F32}(vfpacc${M}x${N}, voutput_min_less_zero_point); 183 184 const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; 185 $for M in range(MR): 186 $for N in range(NR): 187 vfpacc${M}x${N} = ${MIN_F32}(vfpacc${M}x${N}, voutput_max_less_zero_point); 188 189 $for M in range(MR): 190 $for N in range(NR): 191 const int32_t vrndacc${M}x${N} = (int32_t) lrintf(vfpacc${M}x${N}); 192 193 const int32_t voutput_zero_point = params->${PARAMS_STRUCT}.output_zero_point; 194 $for M in range(MR): 195 $for N in range(NR): 196 int32_t vout${M}x${N} = vrndacc${M}x${N} + voutput_zero_point; 197 198 if XNN_LIKELY(nc >= ${NR}) { 199 $for M in reversed(range(MR)): 200 $for N in range(NR): 201 c${M}[${N}] = (${XINT8_T}) vout${M}x${N}; 202 203 $for M in reversed(range(MR)): 204 c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); 205 206 a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks); 207 nc -= ${NR}; 208 } else { 209 $for LOG2N in reversed(range(NR.bit_length() - 1)): 210 if (nc & ${1 << LOG2N}) { 211 $for M in reversed(range(MR)): 212 $for N in range(1 << LOG2N): 213 c${M}[${N}] = (${XINT8_T}) vout${M}x${N}; 214 $if LOG2N != 0: 215 $for N in range(1 << (LOG2N - 1)): 216 vout${M}x${N} = vout${M}x${N + (1 << LOG2N)}; 217 c${M} += ${1 << LOG2N}; 218 } 219 220 nc = 0; 221 } 222 } while (nc != 0); 223} 224