// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert SSE in [2, 4]
$assert not AVX or SSE == 4
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/unaligned.h>
#include <xnnpack/vadd.h>


$PARAMS_STRUCT = "sse4_mul16" if SSE == 4 and DATATYPE == "QS8" else "sse2"
$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
$_MM_CVTEPX8_EPI16 = {"QS8": "_mm_cvtepi8_epi16", "QU8": "_mm_cvtepu8_epi16"}[DATATYPE]
$_MM_PACKXS_EPI16 = {"QS8": "_mm_packs_epi16", "QU8": "_mm_packus_epi16"}[DATATYPE]
$_MM_MIN_EPX8 = {"QS8": "_mm_min_epi8", "QU8": "_mm_min_epu8"}[DATATYPE]
$_MM_MAX_EPX8 = {"QS8": "_mm_max_epi8", "QU8": "_mm_max_epu8"}[DATATYPE]
$ISA = "avx" if AVX else {2: "sse2", 4: "sse41"}[SSE]
void xnn_${DATATYPE.lower()}_vadd_minmax_ukernel__${ISA}_mul16_ld64_x${BATCH_TILE}(
    size_t n,
    const ${XINT8_T}* input_a,
    const ${XINT8_T}* input_b,
    ${XINT8_T}* output,
    const union xnn_${DATATYPE.lower()}_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  const __m128i vbias = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.bias);
  const __m128i va_multiplier_lo = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.a_multiplier_lo);
  const __m128i va_multiplier_hi = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.a_multiplier_hi);
  const __m128i vb_multiplier_lo = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.b_multiplier_lo);
  const __m128i vb_multiplier_hi = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.b_multiplier_hi);
  const __m128i vshift = _mm_cvtsi32_si128((int) params->${PARAMS_STRUCT}.shift);
  const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
  const __m128i voutput_max = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_max);

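  // Main loop: each 32-bit input*multiplier product is assembled from 16-bit
  // halves ("mul16"): the low half is _mm_mullo_epi16(v, multiplier_lo), and
  // the high half is _mm_mulhi_epu16(v, multiplier_lo) +
  // _mm_mullo_epi16(v, multiplier_hi), with a sign correction for QS8 because
  // _mm_mulhi_epu16 treats the sign-extended inputs as unsigned.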
  for (; n >= ${BATCH_TILE} * sizeof(${XINT8_T}); n -= ${BATCH_TILE} * sizeof(${XINT8_T})) {
    $if SSE == 4:
      const __m128i va${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_a));
      const __m128i vb${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_b));
      $for N in range(8, BATCH_TILE, 8):
        const __m128i va${ABC[N:N+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (input_a + ${N})));
        const __m128i vb${ABC[N:N+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (input_b + ${N})));
    $else:
      __m128i va${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_a);
      __m128i vb${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_b);
      $for N in range(8, BATCH_TILE, 8):
        __m128i va${ABC[N:N+8]} = _mm_loadl_epi64((const __m128i*) (input_a + ${N}));
        __m128i vb${ABC[N:N+8]} = _mm_loadl_epi64((const __m128i*) (input_b + ${N}));
    input_a += ${BATCH_TILE};
    input_b += ${BATCH_TILE};

    $if SSE < 4:
      $if DATATYPE == "QU8":
        const __m128i vzero = _mm_setzero_si128();
        $for N in range(0, BATCH_TILE, 8):
          va${ABC[N:N+8]} = _mm_unpacklo_epi8(va${ABC[N:N+8]}, vzero);
          vb${ABC[N:N+8]} = _mm_unpacklo_epi8(vb${ABC[N:N+8]}, vzero);
      $else:
        $for N in range(0, BATCH_TILE, 8):
          va${ABC[N:N+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(va${ABC[N:N+8]}, va${ABC[N:N+8]}), 8);
          vb${ABC[N:N+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${ABC[N:N+8]}, vb${ABC[N:N+8]}), 8);

    $for N in range(0, BATCH_TILE, 8):
      __m128i vaprod${ABC[N:N+8]}hi = _mm_mulhi_epu16(va${ABC[N:N+8]}, va_multiplier_lo);
      __m128i vbprod${ABC[N:N+8]}hi = _mm_mulhi_epu16(vb${ABC[N:N+8]}, vb_multiplier_lo);
      const __m128i vaprod${ABC[N:N+8]}lo = _mm_mullo_epi16(va${ABC[N:N+8]}, va_multiplier_lo);
      const __m128i vbprod${ABC[N:N+8]}lo = _mm_mullo_epi16(vb${ABC[N:N+8]}, vb_multiplier_lo);

    $for N in range(0, BATCH_TILE, 8):
      vaprod${ABC[N:N+8]}hi = _mm_add_epi16(vaprod${ABC[N:N+8]}hi, _mm_mullo_epi16(va${ABC[N:N+8]}, va_multiplier_hi));
      vbprod${ABC[N:N+8]}hi = _mm_add_epi16(vbprod${ABC[N:N+8]}hi, _mm_mullo_epi16(vb${ABC[N:N+8]}, vb_multiplier_hi));

    $if DATATYPE == "QS8":
      $for N in range(0, BATCH_TILE, 8):
        vaprod${ABC[N:N+8]}hi = _mm_sub_epi16(vaprod${ABC[N:N+8]}hi, _mm_and_si128(_mm_srai_epi16(va${ABC[N:N+8]}, 15), va_multiplier_lo));
        vbprod${ABC[N:N+8]}hi = _mm_sub_epi16(vbprod${ABC[N:N+8]}hi, _mm_and_si128(_mm_srai_epi16(vb${ABC[N:N+8]}, 15), vb_multiplier_lo));

    $for N in range(0, BATCH_TILE, 8):
      __m128i vacc${ABC[N:N+4]} = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vaprod${ABC[N:N+8]}lo, vaprod${ABC[N:N+8]}hi));
      __m128i vacc${ABC[N+4:N+8]} = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vaprod${ABC[N:N+8]}lo, vaprod${ABC[N:N+8]}hi));

    $for N in range(0, BATCH_TILE, 8):
      vacc${ABC[N:N+4]} = _mm_add_epi32(vacc${ABC[N:N+4]}, _mm_unpacklo_epi16(vbprod${ABC[N:N+8]}lo, vbprod${ABC[N:N+8]}hi));
      vacc${ABC[N+4:N+8]} = _mm_add_epi32(vacc${ABC[N+4:N+8]}, _mm_unpackhi_epi16(vbprod${ABC[N:N+8]}lo, vbprod${ABC[N:N+8]}hi));

    $for N in range(0, BATCH_TILE, 4):
      vacc${ABC[N:N+4]} = _mm_sra_epi32(vacc${ABC[N:N+4]}, vshift);

    $for N in range(0, BATCH_TILE, 8):
      __m128i vout${ABC[N:N+8]} = _mm_adds_epi16(_mm_packs_epi32(vacc${ABC[N:N+4]}, vacc${ABC[N+4:N+8]}), voutput_zero_point);

    $if DATATYPE == "QS8" and SSE < 4:
      $for N in range(0, BATCH_TILE, 8):
        vout${ABC[N:N+8]} = _mm_max_epi16(vout${ABC[N:N+8]}, voutput_min);

      $for N in range(0, BATCH_TILE, 8):
        vout${ABC[N:N+8]} = _mm_min_epi16(vout${ABC[N:N+8]}, voutput_max);

    $for N in range(0, BATCH_TILE, 16):
      $if N + 8 < BATCH_TILE:
        __m128i vout${ABC[N:N+16]} = ${_MM_PACKXS_EPI16}(vout${ABC[N:N+8]}, vout${ABC[N+8:N+16]});
      $else:
        __m128i vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_PACKXS_EPI16}(vout${ABC[N:N+8]}, vout${ABC[N:N+8]});

    $if DATATYPE == "QU8" or SSE == 4:
      $for N in range(0, BATCH_TILE, 16):
        $if N + 8 < BATCH_TILE:
          vout${ABC[N:N+16]} = ${_MM_MAX_EPX8}(vout${ABC[N:N+16]}, voutput_min);
        $else:
          vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_MAX_EPX8}(vout${ABC[N:N+8]}${ABC[N:N+8]}, voutput_min);

      $for N in range(0, BATCH_TILE, 16):
        $if N + 8 < BATCH_TILE:
          vout${ABC[N:N+16]} = ${_MM_MIN_EPX8}(vout${ABC[N:N+16]}, voutput_max);
        $else:
          vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_MIN_EPX8}(vout${ABC[N:N+8]}${ABC[N:N+8]}, voutput_max);

    $if BATCH_TILE >= 16:
      _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
    $else:
      _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
    $for N in range(16, BATCH_TILE, 16):
      $if N + 8 < BATCH_TILE:
        _mm_storeu_si128((__m128i*) (output + ${N}), vout${ABC[N:N+16]});
      $else:
        _mm_storel_epi64((__m128i*) (output + ${N}), vout${ABC[N:N+8]}${ABC[N:N+8]});
    output += ${BATCH_TILE};
  }
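  // Remainder: up to ${BATCH_TILE-1} elements are left; process them 8 at a
  // time, relying on XNN_OOB_READS to permit loading past the end of the inputs.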
  if XNN_UNLIKELY(n != 0) {
    ${"do " if BATCH_TILE > 8 else ""}{
      $if SSE == 4:
        const __m128i va${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_a));
        const __m128i vb${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) input_b));
      $else:
        __m128i va${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_a);
        __m128i vb${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) input_b);
      $if BATCH_TILE > 8:
        input_a += 8;
        input_b += 8;

      $if SSE < 4:
        $if DATATYPE == "QU8":
          const __m128i vzero = _mm_setzero_si128();
          va${ABC[0:8]} = _mm_unpacklo_epi8(va${ABC[0:8]}, vzero);
          vb${ABC[0:8]} = _mm_unpacklo_epi8(vb${ABC[0:8]}, vzero);
        $else:
          va${ABC[0:8]} = _mm_srai_epi16(_mm_unpacklo_epi8(va${ABC[0:8]}, va${ABC[0:8]}), 8);
          vb${ABC[0:8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${ABC[0:8]}, vb${ABC[0:8]}), 8);

      __m128i vaprod${ABC[0:8]}hi = _mm_mulhi_epu16(va${ABC[0:8]}, va_multiplier_lo);
      __m128i vbprod${ABC[0:8]}hi = _mm_mulhi_epu16(vb${ABC[0:8]}, vb_multiplier_lo);
      const __m128i vaprod${ABC[0:8]}lo = _mm_mullo_epi16(va${ABC[0:8]}, va_multiplier_lo);
      const __m128i vbprod${ABC[0:8]}lo = _mm_mullo_epi16(vb${ABC[0:8]}, vb_multiplier_lo);

      vaprod${ABC[0:8]}hi = _mm_add_epi16(vaprod${ABC[0:8]}hi, _mm_mullo_epi16(va${ABC[0:8]}, va_multiplier_hi));
      vbprod${ABC[0:8]}hi = _mm_add_epi16(vbprod${ABC[0:8]}hi, _mm_mullo_epi16(vb${ABC[0:8]}, vb_multiplier_hi));

      $if DATATYPE == "QS8":
        vaprod${ABC[0:8]}hi = _mm_sub_epi16(vaprod${ABC[0:8]}hi, _mm_and_si128(_mm_srai_epi16(va${ABC[0:8]}, 15), va_multiplier_lo));
        vbprod${ABC[0:8]}hi = _mm_sub_epi16(vbprod${ABC[0:8]}hi, _mm_and_si128(_mm_srai_epi16(vb${ABC[0:8]}, 15), vb_multiplier_lo));

      __m128i vacc${ABC[0:4]} = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vaprod${ABC[0:8]}lo, vaprod${ABC[0:8]}hi));
      __m128i vacc${ABC[4:8]} = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vaprod${ABC[0:8]}lo, vaprod${ABC[0:8]}hi));

      vacc${ABC[0:4]} = _mm_add_epi32(vacc${ABC[0:4]}, _mm_unpacklo_epi16(vbprod${ABC[0:8]}lo, vbprod${ABC[0:8]}hi));
      vacc${ABC[4:8]} = _mm_add_epi32(vacc${ABC[4:8]}, _mm_unpackhi_epi16(vbprod${ABC[0:8]}lo, vbprod${ABC[0:8]}hi));

      vacc${ABC[0:4]} = _mm_sra_epi32(vacc${ABC[0:4]}, vshift);
      vacc${ABC[4:8]} = _mm_sra_epi32(vacc${ABC[4:8]}, vshift);

      __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(vacc${ABC[0:4]}, vacc${ABC[4:8]}), voutput_zero_point);
      $if DATATYPE == "QS8" and SSE < 4:
        vout${ABC[0:8]} = _mm_max_epi16(vout${ABC[0:8]}, voutput_min);
        vout${ABC[0:8]} = _mm_min_epi16(vout${ABC[0:8]}, voutput_max);

      __m128i vout${ABC[0:8]}${ABC[0:8]} = ${_MM_PACKXS_EPI16}(vout${ABC[0:8]}, vout${ABC[0:8]});
      $if DATATYPE == "QU8" or SSE == 4:
        vout${ABC[0:8]}${ABC[0:8]} = ${_MM_MAX_EPX8}(vout${ABC[0:8]}${ABC[0:8]}, voutput_min);
        vout${ABC[0:8]}${ABC[0:8]} = ${_MM_MIN_EPX8}(vout${ABC[0:8]}${ABC[0:8]}, voutput_max);

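      // Store the full 8-byte result, or decompose a shorter tail into 4-, 2-,
      // and 1-byte pieces shifted down from the low end of the register.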
      $if BATCH_TILE > 8:
        if XNN_LIKELY(n >= (8 * sizeof(${XINT8_T}))) {
          _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
          output += 8;
          n -= 8 * sizeof(${XINT8_T});
        } else {
          if (n & (4 * sizeof(${XINT8_T}))) {
            unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}));
            vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
            output += 4;
          }
          if (n & (2 * sizeof(${XINT8_T}))) {
            $if SSE == 4:
              unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0));
            $else:
              unaligned_store_u16(output, (uint16_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}));
            vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
            output += 2;
          }
          if (n & (1 * sizeof(${XINT8_T}))) {
            $if SSE == 4:
              *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
            $else:
              *output = (${XINT8_T}) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
          }
          n = 0;
        }
      $else:
        if (n & (4 * sizeof(${XINT8_T}))) {
          unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}));
          vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
          output += 4;
        }
        if (n & (2 * sizeof(${XINT8_T}))) {
          $if SSE == 4:
            unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0));
          $else:
            unaligned_store_u16(output, (uint16_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]}));
          vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
          output += 2;
        }
        if (n & (1 * sizeof(${XINT8_T}))) {
          $if SSE == 4:
            *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
          $else:
            *output = (${XINT8_T}) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
        }
    }${" while (n != 0);" if BATCH_TILE > 8 else ""}
  }
}