// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert VARIANT in ["LD64", "LD128", "EXTENDED"]
$assert MR <= 4
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>


$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_wasmsimd"
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$WASM_X16X8_LOAD8X8 = "wasm_u16x8_load8x8" if DATATYPE == "QU8" else "wasm_i16x8_load8x8"
$WASM_X8X16_NARROW_I16X8 = "wasm_u8x16_narrow_i16x8" if DATATYPE == "QU8" else "wasm_i8x16_narrow_i16x8"
$WASM_X8X16_MIN = "wasm_u8x16_min" if DATATYPE == "QU8" else "wasm_i8x16_min"
void xnn_${DATATYPE.lower()}_igemm${GEMM_SUFFIX}_minmax_fp32_ukernel_${MR}x4c8__wasmsimd_dot16x2${LOAD_SUFFIX}(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const ${XINT8_T}** restrict a,
    const void* restrict w,
    ${XINT8_T}* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const ${XINT8_T}* zero,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(ks != 0);
  assert(ks % (${MR} * sizeof(void*)) == 0);
  assert(a_offset % sizeof(${XINT8_T}) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 8);
  ${XINT8_T}* c0 = c;
  $for M in range(1, MR):
    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        c${M} = c${M-1};
      }

  $if DATATYPE == "QU8":
    const v128_t vb_zero_point = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.kernel_zero_point);
  do {
    v128_t vacc0x0 = wasm_v128_load32_zero(w);
    $for N in range(1, 4):
      v128_t vacc0x${N} = wasm_v128_load32_zero((const int32_t*) w + ${N});
    $for M in range(1, MR):
      $for N in range(4):
        v128_t vacc${M}x${N} = vacc0x${N};
    w = (const void*) ((const int32_t*) w + 4);

    size_t p = ks;
    do {
      $for M in range(MR):
        const ${XINT8_T}* restrict a${M} = a[${M}];
        if XNN_UNPREDICTABLE(a${M} != zero) {
          a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + a_offset);
        }
      a += ${MR};

      size_t k = 0;
      while (k < kc) {
        $for M in range(MR):
          const v128_t vxa${M} = ${WASM_X16X8_LOAD8X8}(a${M});
          a${M} += 8;

        $if VARIANT == "LD128":
          $for N in range(0, 4, 2):
            $if N == 0:
              const v128_t vb${N}${N+1} = wasm_v128_load(w);
            $else:
              const v128_t vb${N}${N+1} = wasm_v128_load((const ${XINT8_T}*) w + ${N * 8});
            $if DATATYPE == "QU8":
              const v128_t vxb${N} = wasm_i16x8_sub(wasm_u16x8_extend_low_u8x16(vb${N}${N+1}), vb_zero_point);
              const v128_t vxb${N+1} = wasm_i16x8_sub(wasm_u16x8_extend_high_u8x16(vb${N}${N+1}), vb_zero_point);
            $else:
              const v128_t vxb${N} = wasm_i16x8_extend_low_i8x16(vb${N}${N+1});
              const v128_t vxb${N+1} = wasm_i16x8_extend_high_i8x16(vb${N}${N+1});

            $for M in range(MR):
              vacc${M}x${N} = wasm_i32x4_add(vacc${M}x${N}, wasm_i32x4_dot_i16x8(vxa${M}, vxb${N}));
              vacc${M}x${N+1} = wasm_i32x4_add(vacc${M}x${N+1}, wasm_i32x4_dot_i16x8(vxa${M}, vxb${N+1}));
        $else:
          $for N in range(4):
            $if VARIANT == "LD64":
              $if DATATYPE == "QU8":
                $if N == 0:
                  const v128_t vxb${N} = wasm_i16x8_sub(wasm_u16x8_load8x8(w), vb_zero_point);
                $else:
                  const v128_t vxb${N} = wasm_i16x8_sub(wasm_u16x8_load8x8((const ${XINT8_T}*) w + ${N * 8}), vb_zero_point);
              $else:
                $if N == 0:
                  const v128_t vxb${N} = wasm_i16x8_load8x8(w);
                $else:
                  const v128_t vxb${N} = wasm_i16x8_load8x8((const ${XINT8_T}*) w + ${N * 8});
            $elif VARIANT == "EXTENDED":
              $if N == 0:
                const v128_t vxb${N} = wasm_v128_load(w);
              $else:
                const v128_t vxb${N} = wasm_v128_load((const int16_t*) w + ${N * 8});

            $for M in range(MR):
              vacc${M}x${N} = wasm_i32x4_add(vacc${M}x${N}, wasm_i32x4_dot_i16x8(vxa${M}, vxb${N}));

        $if VARIANT == "EXTENDED":
          w = (const void*) ((const int16_t*) w + 32);
        $else:
          w = (const void*) ((const ${XINT8_T}*) w + 32);
        k += 8 * sizeof(${XINT8_T});
      }
      p -= ${MR} * sizeof(void*);
    } while (p != 0);

    $for M in range(MR):
      const v128_t vacc${M}x02 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x0, vacc${M}x2, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x0, vacc${M}x2, 2, 6, 3, 7));
      const v128_t vacc${M}x13 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x1, vacc${M}x3, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x1, vacc${M}x3, 2, 6, 3, 7));

    $for M in range(MR):
      v128_t vacc${M}x0123 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x02, vacc${M}x13, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x02, vacc${M}x13, 2, 6, 3, 7));

    $for M in range(MR):
      vacc${M}x0123 = wasm_f32x4_convert_i32x4(vacc${M}x0123);

    $if DATATYPE == "QC8":
      const v128_t vscale0123 = wasm_v128_load(w);
      w = (const void*) ((const float*) w + 4);
      $for M in range(MR):
        vacc${M}x0123 = wasm_f32x4_mul(vacc${M}x0123, vscale0123);
    $else:
      const v128_t vscale = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.scale);
      $for M in range(MR):
        vacc${M}x0123 = wasm_f32x4_mul(vacc${M}x0123, vscale);

    const v128_t vmagic_bias = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias);
    $for M in range(MR):
      vacc${M}x0123 = wasm_f32x4_add(vacc${M}x0123, vmagic_bias);

    const v128_t vmagic_min = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_min);
    $for M in range(MR):
      vacc${M}x0123 = wasm_i32x4_max(vacc${M}x0123, vmagic_min);

    const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
    $for M in range(MR):
      vacc${M}x0123 = wasm_i32x4_sub(vacc${M}x0123, vmagic_bias_less_output_zero_point);

    $for M in range(0, MR, 2):
      v128_t vacc${M}${min(M+1, MR-1)}x0123 = wasm_i16x8_narrow_i32x4(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123);

    $if MR > 2:
      v128_t vout = ${WASM_X8X16_NARROW_I16X8}(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
    $else:
      v128_t vout = ${WASM_X8X16_NARROW_I16X8}(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

    const v128_t voutput_max = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.output_max);
    vout = ${WASM_X8X16_MIN}(vout, voutput_max);

    if (nc >= 4) {
      $for M in reversed(range(MR)):
        *((float*) c${M}) = (float) wasm_f32x4_extract_lane(vout, ${M});

      $for M in reversed(range(MR)):
        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

      a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks);

      nc -= 4;
    } else {
      $for M in reversed(range(MR)):
        uint32_t vout${M} = wasm_i32x4_extract_lane(vout, ${M});
      if (nc & 2) {
        $for M in reversed(range(MR)):
          *((uint16_t*) c${M}) = (uint16_t) vout${M};
          vout${M} >>= 16;
          c${M} += 2;
      }
      if (nc & 1) {
        $for M in reversed(range(MR)):
          *c${M} = (${XINT8_T}) vout${M};
      }

      nc = 0;
    }
  } while (nc != 0);
}
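
// Illustrative note, not part of the upstream template: assuming DATATYPE=QS8, MR=4, and
// VARIANT=LD128, the generator expands this file into
// xnn_qs8_igemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld128. In that expansion each
// vacc${M}x${N} holds four int32 partial dot products of 8 sign-extended input bytes from
// row M against the 8-element slice of packed-weight column N (wasm_i32x4_dot_i16x8); the
// v32x4 shuffle/add pairs after the k loop reduce those partials to one lane per output
// column before the fp32 requantization with the magic-bias sequence above.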