// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert 1 <= MR <= 2
$assert 1 <= NR <= 2
#include <assert.h>

#include <arm_acle.h>

#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>
#include <xnnpack/gemm.h>
#include <xnnpack/unaligned.h>


$PARAMS_STRUCT = REQUANTIZATION.lower() + "_armsimd32"
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$__XXTB16 = "__uxtb16" if DATATYPE == "QU8" else "__sxtb16"
$__XSAT = "__usat" if DATATYPE == "QU8" else "__ssat"
$__XSUB8 = "__usub8" if DATATYPE == "QU8" else "__ssub8"
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
// ${MR}x${NR} ${DATATYPE} GEMM microkernel with FP32 requantization, built on the
// Armv6 SIMD32 instructions (ACLE intrinsics: ${__XXTB16}/__smlad/__sel/...).
// "c4" layout: A is consumed 4 K-elements at a time as one 32-bit word.
// Template parameters: MR/NR select the output tile; DATATYPE selects signed
// (QS8/QC8) vs. unsigned (QU8) byte arithmetic and per-channel (QC8) vs.
// per-tensor scaling.
void xnn_${DATATYPE.lower()}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c4__armsimd32(
    size_t mr,
    size_t nc,
    size_t kc,
    const ${XINT8_T}* restrict a,
    size_t a_stride,
    const void* restrict w,
    ${XINT8_T}* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);

  // A is read as whole 32-bit words (4 bytes of K per step), so round KC up;
  // the packed weights in w are advanced 4 bytes per group to match, i.e. they
  // are expected to be padded the same way (NOTE(review): guaranteed by the
  // packing routine — not visible here).
  kc = round_up_po2(kc, 4 * sizeof(int8_t));
  const ${XINT8_T}* a0 = a;
  ${XINT8_T}* c0 = c;
  $for M in range(1, MR):
    const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
    // When fewer than MR rows remain (mr < MR), alias the out-of-range row's
    // pointers to the previous row so its loads/stores stay in bounds.
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  $if DATATYPE == "QU8":
    // Negated kernel zero point, folded into the weights as they are widened
    // below (presumably replicated in both 16-bit lanes — set up by params
    // initialization, not visible here).
    const int16x2_t vb_minus_zero_point = (int16x2_t) params->${PARAMS_STRUCT}.minus_kernel_zero_point;
  $if REQUANTIZATION == "FP32":
    $if DATATYPE != "QC8":
      const float vscale = params->${PARAMS_STRUCT}.scale;
    const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias;
  do {
    // Initialize accumulators from the int32 values at the head of each packed
    // weight group in w (the per-channel bias, per XNNPACK packing convention).
    $for N in range(NR):
      int32_t vacc0x${N} = ((const int32_t*) w)[${N}];
    $for M in range(1, MR):
      $for N in range(NR):
        int32_t vacc${M}x${N} = vacc0x${N};
    w = (const void*) ((const int32_t*) w + ${NR});

    size_t k = kc;
    do {
      // Load 4 K-elements per row as one (possibly unaligned) 32-bit word.
      $for M in range(MR):
        const int8x4_t va${M} = (int8x4_t) unaligned_load_s32(a${M}); a${M} += 4;

      // Widen bytes 0 and 2 (c02), then — after rotating by 8 bits — bytes 1
      // and 3 (c13), into pairs of 16-bit lanes.
      $for M in range(MR):
        const int16x2_t va${M}c02 = ${__XXTB16}(va${M});
        const int16x2_t va${M}c13 = ${__XXTB16}(__ror(va${M}, 8));

      $for N in range(NR):
        const int8x4_t vb${N} = *((const int8x4_t*) w); w = (const int8_t*) w + 4;
        $if DATATYPE == "QU8":
          // Widen and subtract the kernel zero point in one step (__uxtab16).
          const int16x2_t vb${N}c02 = __uxtab16(vb_minus_zero_point, vb${N});
        $else:
          const int16x2_t vb${N}c02 = __sxtb16(vb${N});

        // __smlad: two 16x16 multiplies plus accumulate per instruction.
        $for M in range(MR):
          vacc${M}x${N} = __smlad(va${M}c02, vb${N}c02, vacc${M}x${N});

        $if DATATYPE == "QU8":
          const int16x2_t vb${N}c13 = __uxtab16(vb_minus_zero_point, __ror(vb${N}, 8));
        $else:
          const int16x2_t vb${N}c13 = __sxtb16(__ror(vb${N}, 8));
        $for M in range(MR):
          vacc${M}x${N} = __smlad(va${M}c13, vb${N}c13, vacc${M}x${N});

      k -= 4 * sizeof(${XINT8_T});
    } while (k != 0);

    // FP32 requantization: convert the int32 accumulators to float and scale.
    $for M in range(MR):
      $for N in range(NR):
        float vfpacc${M}x${N} = (float) vacc${M}x${N};

    $if DATATYPE == "QC8":
      // QC8: per-channel scales are stored in w right after the weights.
      $for N in range(NR):
        const float vscale${N} = ((const float*) w)[${N}];
        $for M in range(MR):
          vfpacc${M}x${N} *= vscale${N};
      w = (const void*) ((const float*) w + ${NR});
    $else:
      $for M in range(MR):
        $for N in range(NR):
          vfpacc${M}x${N} *= vscale;

    // Magic-bias rounding: adding vmagic_bias places the rounded integer in
    // the low mantissa bits, so reinterpreting the float as an integer...
    $for M in range(MR):
      $for N in range(NR):
        vfpacc${M}x${N} += vmagic_bias;

    $for M in range(MR):
      $for N in range(NR):
        int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N});

    // ...and subtracting (bias - output_zero_point) with saturation (__qsub)
    // yields the requantized value with the output zero point applied.
    const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point;
    $for M in range(MR):
      $for N in range(NR):
        vout${M}x${N} = __qsub(vout${M}x${N}, vmagic_bias_less_zero_point);

    // Saturate to 8 bits (signed __ssat / unsigned __usat).
    $for M in range(MR):
      $for N in range(NR):
        vout${M}x${N} = ${__XSAT}(vout${M}x${N}, 8);

    // Pack all MRxNR output bytes into a single 32-bit word for byte-wise
    // SIMD clamping below.
    $for M in range(MR):
      $if NR == 1:
        const uint32_t vout${M} = (uint32_t) vout${M}x0;
      $else:
        const uint32_t vout${M} = (uint32_t) (uint8_t) vout${M}x0 | ((uint32_t) vout${M}x1 << 8);

    $if MR == 1:
      uint32_t vout = vout0;
    $else:
      uint32_t vout = (uint32_t) (uint16_t) vout0 | (vout1 << 16);

    // Byte-wise clamp: ${__XSUB8} is executed solely for its side effect of
    // setting the APSR.GE flags; __sel then selects, per byte, between vout
    // and the bound. First clamp to output_min (keep the larger byte)...
    const int8x4_t voutput_min = (int8x4_t) params->${PARAMS_STRUCT}.output_min;
    ${__XSUB8}((int8x4_t) vout, voutput_min);
    vout = (uint32_t) __sel((uint8x4_t) vout, (uint8x4_t) voutput_min);

    // ...then clamp to output_max (keep the smaller byte).
    const int8x4_t voutput_max = (int8x4_t) params->${PARAMS_STRUCT}.output_max;
    ${__XSUB8}((int8x4_t) vout, voutput_max);
    vout = (uint32_t) __sel((uint8x4_t) voutput_max, (uint8x4_t) vout);

    $if NR == 2:
      if XNN_LIKELY(nc >= ${NR}) {
        // Full-width store: ${NR} bytes per row, rows packed 16 bits apart in vout.
        $for M in range(MR):
          unaligned_store_u16(c${M}, (uint16_t) vout);
          $if M + 1 != MR:
            vout >>= 16;

        // Rewind A to the start of each row for the next N-block.
        $for M in range(MR):
          a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);

        $for M in range(MR):
          c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

        nc -= ${NR};
      } else {
        // Tail (nc == 1): store a single byte per row; done after this block.
        $for M in range(MR):
          *c${M} = (${XINT8_T}) vout;
          $if M + 1 != MR:
            vout >>= 16;

        nc = 0;
      }
    $else:
      $for M in range(MR):
        *c${M} = (${XINT8_T}) vout;
        $if M + 1 != MR:
          vout >>= 16;

      // Rewind A to the start of each row for the next N-block.
      $for M in range(MR):
        a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);

      $for M in range(MR):
        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

      nc -= 1;
  } while (nc != 0);
}