// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$assert DIV_ALGO in ["div", "rcp"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/vunary.h>


void xnn_f16_vsigmoid_ukernel__avx2_rr1_p2_${DIV_ALGO}_x${BATCH_TILE}(
    size_t batch,
    const void* input,
    void* output,
    const union xnn_f16_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch % sizeof(uint16_t) == 0);

  const __m256 vsign_mask = _mm256_load_ps(params->avx2_rr1_p2.sign_mask);
  const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p2.magic_bias);
  const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p2.log2e);
  const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p2.minus_ln2);
  const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p2.c2);
  const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p2.c1);
  const __m256 vone = _mm256_load_ps(params->avx2_rr1_p2.one);
  const __m256 vdenorm_cutoff = _mm256_load_ps(params->avx2_rr1_p2.denorm_cutoff);

  const uint16_t* i = (const uint16_t*) input;
  uint16_t* o = (uint16_t*) output;
  $if BATCH_TILE > 8:
    for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) {
      const __m256 vx${ABC[0]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
      $for N in range(1, SIMD_TILE):
        const __m256 vx${ABC[N]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + ${N * 8})));
      i += ${BATCH_TILE};

      $for N in range(SIMD_TILE):
        const __m256 vz${ABC[N]} = _mm256_or_ps(vx${ABC[N]}, vsign_mask);

      $for N in range(SIMD_TILE):
        __m256 vn${ABC[N]} = _mm256_fmadd_ps(vz${ABC[N]}, vlog2e, vmagic_bias);

      $for N in range(SIMD_TILE):
        const __m256 vs${ABC[N]} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${ABC[N]}), 23));

      $for N in range(SIMD_TILE):
        vn${ABC[N]} = _mm256_sub_ps(vn${ABC[N]}, vmagic_bias);

      $for N in range(SIMD_TILE):
        __m256 vt${ABC[N]} = _mm256_fmadd_ps(vn${ABC[N]}, vminus_ln2, vz${ABC[N]});

      $for N in range(SIMD_TILE):
        const __m256 vp${ABC[N]} = _mm256_fmadd_ps(vc2, vt${ABC[N]}, vc1);

      $for N in range(SIMD_TILE):
        vt${ABC[N]} = _mm256_mul_ps(vt${ABC[N]}, vs${ABC[N]});

      $for N in range(SIMD_TILE):
        const __m256 ve${ABC[N]} = _mm256_fmadd_ps(vt${ABC[N]}, vp${ABC[N]}, vs${ABC[N]});

      $for N in range(SIMD_TILE):
        const __m256 vd${ABC[N]} = _mm256_add_ps(ve${ABC[N]}, vone);

      $if DIV_ALGO == "div":
        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_div_ps(ve${ABC[N]}, vd${ABC[N]});
      $else:
        $for N in range(SIMD_TILE):
          const __m256 vr${ABC[N]} = _mm256_rcp_ps(vd${ABC[N]});

        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_mul_ps(ve${ABC[N]}, vr${ABC[N]});

      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_andnot_ps(_mm256_cmp_ps(vz${ABC[N]}, vdenorm_cutoff, _CMP_LT_OS), vf${ABC[N]});

      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_blendv_ps(_mm256_sub_ps(vone, vf${ABC[N]}), vf${ABC[N]}, vx${ABC[N]});

      _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf${ABC[0]}, _MM_FROUND_NO_EXC));
      $for N in range(1, SIMD_TILE):
        _mm_storeu_si128((__m128i*) (o + ${N * 8}), _mm256_cvtps_ph(vf${ABC[N]}, _MM_FROUND_NO_EXC));
      o += ${BATCH_TILE};
    }
  for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) {
    const __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
    i += 8;

    // z = -|x|, so that exp(z) is only evaluated on the non-positive range.
    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    // n = round(z * log2(e)) via the magic-bias trick; s = 2**n reconstructed
    // by shifting the low bits of the biased value into the exponent field.
    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    // t = z - n * ln(2): reduced argument for the degree-2 polynomial.
    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

    // e = s * (1 + t * (c1 + t * c2)) ~= exp(z).
    const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);

    // f = e / (e + 1) = sigmoid(z), via true division or a single RCPPS
    // approximation (accurate enough for half-precision output).
    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      const __m256 vr = _mm256_rcp_ps(vd);
      __m256 vf = _mm256_mul_ps(ve, vr);

    // Flush the result to 0 where z is below the cutoff (exp(z) underflows),
    // then use sigmoid(x) = 1 - sigmoid(-x) for non-negative inputs.
    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC));
    o += 8;
  }
  // Tail: 1-7 remaining elements, computed on a full vector and stored partially.
  if XNN_UNLIKELY(batch != 0) {
    assert(batch >= 1 * sizeof(uint16_t));
    assert(batch <= 7 * sizeof(uint16_t));
    const __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));

    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

    const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);

    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      const __m256 vr = _mm256_rcp_ps(vd);
      __m256 vf = _mm256_mul_ps(ve, vr);

    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    // Store the remaining half-precision values 4, 2, and 1 at a time.
    __m128i vh = _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC);
    if (batch & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) o, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      o += 4;
    }
    if (batch & (2 * sizeof(uint16_t))) {
      _mm_storeu_si32(o, vh);
      vh = _mm_srli_epi64(vh, 32);
      o += 2;
    }
    if (batch & (1 * sizeof(uint16_t))) {
      *o = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}
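
// Usage sketch (illustrative only, not part of the kernel): a concrete instance
// generated from this template, e.g. xnn_f16_vsigmoid_ukernel__avx2_rr1_p2_div_x16,
// is called with the batch size in bytes and a parameter struct prepared by an
// XNNPACK initializer. The initializer name below follows XNNPACK's usual naming
// convention and is an assumption, not something defined in this file:
//
//   union xnn_f16_sigmoid_params params;
//   xnn_init_f16_sigmoid_avx2_rr1_p2_params(&params);  // assumed initializer name
//   xnn_f16_vsigmoid_ukernel__avx2_rr1_p2_div_x16(
//       num_elements * sizeof(uint16_t),  // batch is measured in bytes
//       input_fp16, output_fp16, &params);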