// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR % 8 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/gemm.h>
#include <xnnpack/intrinsics-polyfill.h>


void xnn_f16_gemm_minmax_ukernel_${MR}x${NR}__avx2_broadcast(
    size_t mr,
    size_t nc,
    size_t kc,
    const void*restrict a,
    size_t a_stride,
    const void*restrict w,
    void*restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(uint16_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  // Per-row pointers into A and C; rows beyond mr alias the row above them.
  const uint16_t* a0 = a;
  uint16_t* c0 = c;
  $for M in range(1, MR):
    const uint16_t* a${M} = (const uint16_t*) ((uintptr_t) a${M-1} + a_stride);
    uint16_t* c${M} = (uint16_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
    // Initialize the accumulators from the packed weights and replicate them across all rows.
    __m256 vacc0x${ABC[0:8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
    $for N in range(8, NR, 8):
      __m256 vacc0x${ABC[N:N+8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + ${N})));
    $for M in range(1, MR):
      $for N in range(0, NR, 8):
        __m256 vacc${M}x${ABC[N:N+8]} = vacc0x${ABC[N:N+8]};
    w = (const uint16_t*) w + ${NR};

    // Multiply-accumulate over K, broadcasting one fp16 element of each A row per iteration;
    // each result is rounded back to fp16 right after the FMA.
    size_t k = kc;
    do {
      $for M in range(MR):
        const __m256 va${M} = _mm256_cvtph_ps(_mm_set1_epi16((short) *a${M}));
        a${M} += 1;

      const __m256 vb${ABC[0:8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
      $for N in range(8, NR, 8):
        const __m256 vb${ABC[N:N+8]} = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + ${N})));
      w = (const uint16_t*) w + ${NR};

      $for N in range(0, NR, 8):
        $for M in range(MR):
          vacc${M}x${ABC[N:N+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va${M}, vb${ABC[N:N+8]}, vacc${M}x${ABC[N:N+8]}), _MM_FROUND_NO_EXC));

      k -= sizeof(uint16_t);
    } while (k != 0);

    // Clamp the accumulators to the [min, max] output range.
    const __m256 vmin = _mm256_load_ps(params->avx.min);
    $for N in range(0, NR, 8):
      $for M in range(MR):
        vacc${M}x${ABC[N:N+8]} = _mm256_max_ps(vacc${M}x${ABC[N:N+8]}, vmin);

    const __m256 vmax = _mm256_load_ps(params->avx.max);
    $for N in range(0, NR, 8):
      $for M in range(MR):
        vacc${M}x${ABC[N:N+8]} = _mm256_min_ps(vacc${M}x${ABC[N:N+8]}, vmax);

    if XNN_LIKELY(nc >= ${NR}) {
      // Full-tile path: convert each row of accumulators to fp16 and store all ${NR} columns.
      $for M in range(MR):
        _mm_storeu_si128((__m128i*) c${M}, _mm256_cvtps_ph(vacc${M}x${ABC[0:8]}, _MM_FROUND_NO_EXC));
        $for N in range(8, NR, 8):
          _mm_storeu_si128((__m128i*) (c${M} + ${N}), _mm256_cvtps_ph(vacc${M}x${ABC[N:N+8]}, _MM_FROUND_NO_EXC));
        c${M} = (uint16_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const uint16_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      // Remainder path: store the last nc (< ${NR}) columns, halving the store width at each step.
      $for LOG2N in reversed(range(NR.bit_length())):
        $if LOG2N == 3:
          $for M in range(MR):
            __m128i vh${M}x${ABC[0:8]} = _mm256_cvtps_ph(vacc${M}x${ABC[0:8]}, _MM_FROUND_NO_EXC);
        $if NR != 1 << LOG2N:
          if (nc & ${1 << LOG2N}) {
            $if LOG2N >= 4:
              $for M in range(MR):
                _mm_storeu_si128((__m128i*) c${M}, _mm256_cvtps_ph(vacc${M}x${ABC[0:8]}, _MM_FROUND_NO_EXC));
                $for N in range(8, 1 << LOG2N, 8):
                  _mm_storeu_si128((__m128i*) (c${M} + ${N}), _mm256_cvtps_ph(vacc${M}x${ABC[N:N+8]}, _MM_FROUND_NO_EXC));

              $for M in range(MR):
                $for N in range(0, 1 << (LOG2N - 1), 8):
                  vacc${M}x${ABC[N:N+8]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+8]};

              $for M in range(MR):
                c${M} += ${1 << LOG2N};
            $elif LOG2N == 3:
              $for M in range(MR):
                _mm_storeu_si128((__m128i*) c${M}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                vh${M}x${ABC[0:8]} = _mm256_cvtps_ph(vacc${M}x${ABC[8:16]}, _MM_FROUND_NO_EXC);

              $for M in range(MR):
                c${M} += ${1 << LOG2N};
            $elif LOG2N == 2:
              $for M in range(MR):
                _mm_storel_epi64((__m128i*) c${M}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                vh${M}x${ABC[0:8]} = _mm_unpackhi_epi64(vh${M}x${ABC[0:8]}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                c${M} += 4;
            $elif LOG2N == 1:
              $for M in range(MR):
                _mm_storeu_si32(c${M}, vh${M}x${ABC[0:8]});

              $for M in range(MR):
                vh${M}x${ABC[0:8]} = _mm_srli_epi64(vh${M}x${ABC[0:8]}, 32);

              $for M in range(MR):
                c${M} += 2;
            $elif LOG2N == 0:
              $for M in range(MR):
                *c${M} = (uint16_t) _mm_extract_epi16(vh${M}x${ABC[0:8]}, 0);
          }

      nc = 0;
    }
  } while (nc != 0);
}
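
// Usage sketch (illustrative only, not part of the generated kernel). Assuming this
// template is instantiated with MR=4 and NR=16 (e.g. via XNNPACK's xngen tool), and
// assuming caller-side variables a, packed_w, c, k, n and an initialized
// xnn_f16_minmax_params `params` (all hypothetical names), one 4-row, 16-column output
// tile could be computed as:
//
//   xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast(
//       /*mr=*/4,
//       /*nc=*/16,
//       /*kc=*/k * sizeof(uint16_t),           // K extent, in bytes
//       /*a=*/a,
//       /*a_stride=*/k * sizeof(uint16_t),     // row stride of A, in bytes
//       /*w=*/packed_w,                        // weights pre-packed for this kernel
//       /*c=*/c,
//       /*cm_stride=*/n * sizeof(uint16_t),    // row stride of C, in bytes
//       /*cn_stride=*/16 * sizeof(uint16_t),   // stride to the next 16-column block of C
//       &params);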