1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2022 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker// 3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker$assert REQUANTIZATION == "FP32" 7*4bdc9457SAndroid Build Coastguard Worker$assert DATATYPE in ["QC8", "QS8", "QU8"] 8*4bdc9457SAndroid Build Coastguard Worker$assert 1 <= MR <= 2 9*4bdc9457SAndroid Build Coastguard Worker$assert 1 <= NR <= 2 10*4bdc9457SAndroid Build Coastguard Worker#include <assert.h> 11*4bdc9457SAndroid Build Coastguard Worker 12*4bdc9457SAndroid Build Coastguard Worker#include <arm_acle.h> 13*4bdc9457SAndroid Build Coastguard Worker 14*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/intrinsics-polyfill.h> 15*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/math.h> 16*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/gemm.h> 17*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/unaligned.h> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker$PARAMS_STRUCT = REQUANTIZATION.lower() + "_armsimd32" 21*4bdc9457SAndroid Build Coastguard Worker$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower() 22*4bdc9457SAndroid Build Coastguard Worker$__XXTB16 = "__uxtb16" if DATATYPE == "QU8" else "__sxtb16" 23*4bdc9457SAndroid Build Coastguard Worker$__XSAT = "__usat" if DATATYPE == "QU8" else "__ssat" 24*4bdc9457SAndroid Build Coastguard Worker$__XSUB8 = "__usub8" if DATATYPE == "QU8" else "__ssub8" 25*4bdc9457SAndroid Build Coastguard Worker$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 26*4bdc9457SAndroid Build Coastguard Workervoid xnn_${DATATYPE.lower()}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c4__armsimd32( 27*4bdc9457SAndroid Build Coastguard Worker size_t mr, 28*4bdc9457SAndroid Build Coastguard Worker size_t nc, 29*4bdc9457SAndroid Build Coastguard Worker size_t kc, 30*4bdc9457SAndroid Build Coastguard Worker const ${XINT8_T}* restrict a, 31*4bdc9457SAndroid Build Coastguard Worker size_t a_stride, 32*4bdc9457SAndroid Build Coastguard Worker const void* restrict w, 33*4bdc9457SAndroid Build Coastguard Worker ${XINT8_T}* restrict c, 34*4bdc9457SAndroid Build Coastguard Worker size_t cm_stride, 35*4bdc9457SAndroid Build Coastguard Worker size_t cn_stride, 36*4bdc9457SAndroid Build Coastguard Worker const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) 37*4bdc9457SAndroid Build Coastguard Worker{ 38*4bdc9457SAndroid Build Coastguard Worker assert(mr != 0); 39*4bdc9457SAndroid Build Coastguard Worker assert(mr <= ${MR}); 40*4bdc9457SAndroid Build Coastguard Worker assert(nc != 0); 41*4bdc9457SAndroid Build Coastguard Worker assert(kc != 0); 42*4bdc9457SAndroid Build Coastguard Worker 43*4bdc9457SAndroid Build Coastguard Worker kc = round_up_po2(kc, 4 * sizeof(int8_t)); 44*4bdc9457SAndroid Build Coastguard Worker const ${XINT8_T}* a0 = a; 45*4bdc9457SAndroid Build Coastguard Worker ${XINT8_T}* c0 = c; 46*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, MR): 47*4bdc9457SAndroid Build Coastguard Worker const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride); 48*4bdc9457SAndroid Build Coastguard Worker ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride); 49*4bdc9457SAndroid Build Coastguard Worker $if M % 2 == 0: 50*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(mr <= ${M}) { 51*4bdc9457SAndroid Build Coastguard Worker a${M} = a${M-1}; 52*4bdc9457SAndroid Build Coastguard Worker c${M} = c${M-1}; 53*4bdc9457SAndroid Build Coastguard Worker } 54*4bdc9457SAndroid Build Coastguard Worker $elif M + 1 == MR: 55*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(mr != ${M+1}) { 56*4bdc9457SAndroid Build Coastguard Worker a${M} = a${M-1}; 57*4bdc9457SAndroid Build Coastguard Worker c${M} = c${M-1}; 58*4bdc9457SAndroid Build Coastguard Worker } 59*4bdc9457SAndroid Build Coastguard Worker $else: 60*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(mr < ${M+1}) { 61*4bdc9457SAndroid Build Coastguard Worker a${M} = a${M-1}; 62*4bdc9457SAndroid Build Coastguard Worker c${M} = c${M-1}; 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker 65*4bdc9457SAndroid Build Coastguard Worker $if DATATYPE == "QU8": 66*4bdc9457SAndroid Build Coastguard Worker const int16x2_t vb_minus_zero_point = (int16x2_t) params->${PARAMS_STRUCT}.minus_kernel_zero_point; 67*4bdc9457SAndroid Build Coastguard Worker $if REQUANTIZATION == "FP32": 68*4bdc9457SAndroid Build Coastguard Worker $if DATATYPE != "QC8": 69*4bdc9457SAndroid Build Coastguard Worker const float vscale = params->${PARAMS_STRUCT}.scale; 70*4bdc9457SAndroid Build Coastguard Worker const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; 71*4bdc9457SAndroid Build Coastguard Worker do { 72*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 73*4bdc9457SAndroid Build Coastguard Worker int32_t vacc0x${N} = ((const int32_t*) w)[${N}]; 74*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, MR): 75*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 76*4bdc9457SAndroid Build Coastguard Worker int32_t vacc${M}x${N} = vacc0x${N}; 77*4bdc9457SAndroid Build Coastguard Worker w = (const void*) ((const int32_t*) w + ${NR}); 78*4bdc9457SAndroid Build Coastguard Worker 79*4bdc9457SAndroid Build Coastguard Worker size_t k = kc; 80*4bdc9457SAndroid Build Coastguard Worker do { 81*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 82*4bdc9457SAndroid Build Coastguard Worker const int8x4_t va${M} = (int8x4_t) unaligned_load_s32(a${M}); a${M} += 4; 83*4bdc9457SAndroid Build Coastguard Worker 84*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 85*4bdc9457SAndroid Build Coastguard Worker const int16x2_t va${M}c02 = ${__XXTB16}(va${M}); 86*4bdc9457SAndroid Build Coastguard Worker const int16x2_t va${M}c13 = ${__XXTB16}(__ror(va${M}, 8)); 87*4bdc9457SAndroid Build Coastguard Worker 88*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 89*4bdc9457SAndroid Build Coastguard Worker const int8x4_t vb${N} = *((const int8x4_t*) w); w = (const int8_t*) w + 4; 90*4bdc9457SAndroid Build Coastguard Worker $if DATATYPE == "QU8": 91*4bdc9457SAndroid Build Coastguard Worker const int16x2_t vb${N}c02 = __uxtab16(vb_minus_zero_point, vb${N}); 92*4bdc9457SAndroid Build Coastguard Worker $else: 93*4bdc9457SAndroid Build Coastguard Worker const int16x2_t vb${N}c02 = __sxtb16(vb${N}); 94*4bdc9457SAndroid Build Coastguard Worker 95*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 96*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${N} = __smlad(va${M}c02, vb${N}c02, vacc${M}x${N}); 97*4bdc9457SAndroid Build Coastguard Worker 98*4bdc9457SAndroid Build Coastguard Worker $if DATATYPE == "QU8": 99*4bdc9457SAndroid Build Coastguard Worker const int16x2_t vb${N}c13 = __uxtab16(vb_minus_zero_point, __ror(vb${N}, 8)); 100*4bdc9457SAndroid Build Coastguard Worker $else: 101*4bdc9457SAndroid Build Coastguard Worker const int16x2_t vb${N}c13 = __sxtb16(__ror(vb${N}, 8)); 102*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 103*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${N} = __smlad(va${M}c13, vb${N}c13, vacc${M}x${N}); 104*4bdc9457SAndroid Build Coastguard Worker 105*4bdc9457SAndroid Build Coastguard Worker k -= 4 * sizeof(${XINT8_T}); 106*4bdc9457SAndroid Build Coastguard Worker } while (k != 0); 107*4bdc9457SAndroid Build Coastguard Worker 108*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 109*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 110*4bdc9457SAndroid Build Coastguard Worker float vfpacc${M}x${N} = (float) vacc${M}x${N}; 111*4bdc9457SAndroid Build Coastguard Worker 112*4bdc9457SAndroid Build Coastguard Worker $if DATATYPE == "QC8": 113*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 114*4bdc9457SAndroid Build Coastguard Worker const float vscale${N} = ((const float*) w)[${N}]; 115*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 116*4bdc9457SAndroid Build Coastguard Worker vfpacc${M}x${N} *= vscale${N}; 117*4bdc9457SAndroid Build Coastguard Worker w = (const void*) ((const float*) w + ${NR}); 118*4bdc9457SAndroid Build Coastguard Worker $else: 119*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 120*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 121*4bdc9457SAndroid Build Coastguard Worker vfpacc${M}x${N} *= vscale; 122*4bdc9457SAndroid Build Coastguard Worker 123*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 124*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 125*4bdc9457SAndroid Build Coastguard Worker vfpacc${M}x${N} += vmagic_bias; 126*4bdc9457SAndroid Build Coastguard Worker 127*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 128*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 129*4bdc9457SAndroid Build Coastguard Worker int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N}); 130*4bdc9457SAndroid Build Coastguard Worker 131*4bdc9457SAndroid Build Coastguard Worker const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point; 132*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 133*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 134*4bdc9457SAndroid Build Coastguard Worker vout${M}x${N} = __qsub(vout${M}x${N}, vmagic_bias_less_zero_point); 135*4bdc9457SAndroid Build Coastguard Worker 136*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 137*4bdc9457SAndroid Build Coastguard Worker $for N in range(NR): 138*4bdc9457SAndroid Build Coastguard Worker vout${M}x${N} = ${__XSAT}(vout${M}x${N}, 8); 139*4bdc9457SAndroid Build Coastguard Worker 140*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 141*4bdc9457SAndroid Build Coastguard Worker $if NR == 1: 142*4bdc9457SAndroid Build Coastguard Worker const uint32_t vout${M} = (uint32_t) vout${M}x0; 143*4bdc9457SAndroid Build Coastguard Worker $else: 144*4bdc9457SAndroid Build Coastguard Worker const uint32_t vout${M} = (uint32_t) (uint8_t) vout${M}x0 | ((uint32_t) vout${M}x1 << 8); 145*4bdc9457SAndroid Build Coastguard Worker 146*4bdc9457SAndroid Build Coastguard Worker $if MR == 1: 147*4bdc9457SAndroid Build Coastguard Worker uint32_t vout = vout0; 148*4bdc9457SAndroid Build Coastguard Worker $else: 149*4bdc9457SAndroid Build Coastguard Worker uint32_t vout = (uint32_t) (uint16_t) vout0 | (vout1 << 16); 150*4bdc9457SAndroid Build Coastguard Worker 151*4bdc9457SAndroid Build Coastguard Worker const int8x4_t voutput_min = (int8x4_t) params->${PARAMS_STRUCT}.output_min; 152*4bdc9457SAndroid Build Coastguard Worker ${__XSUB8}((int8x4_t) vout, voutput_min); 153*4bdc9457SAndroid Build Coastguard Worker vout = (uint32_t) __sel((uint8x4_t) vout, (uint8x4_t) voutput_min); 154*4bdc9457SAndroid Build Coastguard Worker 155*4bdc9457SAndroid Build Coastguard Worker const int8x4_t voutput_max = (int8x4_t) params->${PARAMS_STRUCT}.output_max; 156*4bdc9457SAndroid Build Coastguard Worker ${__XSUB8}((int8x4_t) vout, voutput_max); 157*4bdc9457SAndroid Build Coastguard Worker vout = (uint32_t) __sel((uint8x4_t) voutput_max, (uint8x4_t) vout); 158*4bdc9457SAndroid Build Coastguard Worker 159*4bdc9457SAndroid Build Coastguard Worker $if NR == 2: 160*4bdc9457SAndroid Build Coastguard Worker if XNN_LIKELY(nc >= ${NR}) { 161*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 162*4bdc9457SAndroid Build Coastguard Worker unaligned_store_u16(c${M}, (uint16_t) vout); 163*4bdc9457SAndroid Build Coastguard Worker $if M + 1 != MR: 164*4bdc9457SAndroid Build Coastguard Worker vout >>= 16; 165*4bdc9457SAndroid Build Coastguard Worker 166*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 167*4bdc9457SAndroid Build Coastguard Worker a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); 168*4bdc9457SAndroid Build Coastguard Worker 169*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 170*4bdc9457SAndroid Build Coastguard Worker c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); 171*4bdc9457SAndroid Build Coastguard Worker 172*4bdc9457SAndroid Build Coastguard Worker nc -= ${NR}; 173*4bdc9457SAndroid Build Coastguard Worker } else { 174*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 175*4bdc9457SAndroid Build Coastguard Worker *c${M} = (${XINT8_T}) vout; 176*4bdc9457SAndroid Build Coastguard Worker $if M + 1 != MR: 177*4bdc9457SAndroid Build Coastguard Worker vout >>= 16; 178*4bdc9457SAndroid Build Coastguard Worker 179*4bdc9457SAndroid Build Coastguard Worker nc = 0; 180*4bdc9457SAndroid Build Coastguard Worker } 181*4bdc9457SAndroid Build Coastguard Worker $else: 182*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 183*4bdc9457SAndroid Build Coastguard Worker *c${M} = (${XINT8_T}) vout; 184*4bdc9457SAndroid Build Coastguard Worker $if M + 1 != MR: 185*4bdc9457SAndroid Build Coastguard Worker vout >>= 16; 186*4bdc9457SAndroid Build Coastguard Worker 187*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 188*4bdc9457SAndroid Build Coastguard Worker a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); 189*4bdc9457SAndroid Build Coastguard Worker 190*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 191*4bdc9457SAndroid Build Coastguard Worker c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride); 192*4bdc9457SAndroid Build Coastguard Worker 193*4bdc9457SAndroid Build Coastguard Worker nc -= 1; 194*4bdc9457SAndroid Build Coastguard Worker } while (nc != 0); 195*4bdc9457SAndroid Build Coastguard Worker} 196