// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$assert DIV_ALGO in ["div", "nr1fma", "nr2fma"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>


void xnn_f32_vsigmoid_ukernel__avx2_rr1_p5_${DIV_ALGO}_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n % sizeof(float) == 0);

  const __m256 vsign_mask = _mm256_load_ps(params->avx2_rr1_p5.sign_mask);
  const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p5.magic_bias);
  const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p5.log2e);
  const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p5.minus_ln2);
  const __m256 vc5 = _mm256_load_ps(params->avx2_rr1_p5.c5);
  const __m256 vc4 = _mm256_load_ps(params->avx2_rr1_p5.c4);
  const __m256 vc3 = _mm256_load_ps(params->avx2_rr1_p5.c3);
  const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p5.c2);
  const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p5.c1);
  const __m256 vone = _mm256_load_ps(params->avx2_rr1_p5.one);
  const __m256 vdenorm_cutoff = _mm256_load_ps(params->avx2_rr1_p5.denorm_cutoff);

  $if BATCH_TILE > 8:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      const __m256 vx${ABC[0]} = _mm256_loadu_ps(x);
      $for N in range(1, SIMD_TILE):
        const __m256 vx${ABC[N]} = _mm256_loadu_ps(x + ${N * 8});
      x += ${BATCH_TILE};

      $for N in range(SIMD_TILE):
        const __m256 vz${ABC[N]} = _mm256_or_ps(vx${ABC[N]}, vsign_mask);

      $for N in range(SIMD_TILE):
        __m256 vn${ABC[N]} = _mm256_fmadd_ps(vz${ABC[N]}, vlog2e, vmagic_bias);

      $for N in range(SIMD_TILE):
        const __m256 vs${ABC[N]} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${ABC[N]}), 23));

      $for N in range(SIMD_TILE):
        vn${ABC[N]} = _mm256_sub_ps(vn${ABC[N]}, vmagic_bias);

      $for N in range(SIMD_TILE):
        __m256 vt${ABC[N]} = _mm256_fmadd_ps(vn${ABC[N]}, vminus_ln2, vz${ABC[N]});

      $for N in range(SIMD_TILE):
        __m256 vp${ABC[N]} = _mm256_fmadd_ps(vc5, vt${ABC[N]}, vc4);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_fmadd_ps(vp${ABC[N]}, vt${ABC[N]}, vc3);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_fmadd_ps(vp${ABC[N]}, vt${ABC[N]}, vc2);

      $for N in range(SIMD_TILE):
        vp${ABC[N]} = _mm256_fmadd_ps(vp${ABC[N]}, vt${ABC[N]}, vc1);

      $for N in range(SIMD_TILE):
        vt${ABC[N]} = _mm256_mul_ps(vt${ABC[N]}, vs${ABC[N]});

      $for N in range(SIMD_TILE):
        const __m256 ve${ABC[N]} = _mm256_fmadd_ps(vt${ABC[N]}, vp${ABC[N]}, vs${ABC[N]});

      $for N in range(SIMD_TILE):
        const __m256 vd${ABC[N]} = _mm256_add_ps(ve${ABC[N]}, vone);

      $if DIV_ALGO == "div":
        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_div_ps(ve${ABC[N]}, vd${ABC[N]});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vr${ABC[N]} = _mm256_rcp_ps(vd${ABC[N]});

        $for N in range(SIMD_TILE):
          vr${ABC[N]} = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr${ABC[N]}, vd${ABC[N]}, vone), vr${ABC[N]}, vr${ABC[N]});

        $if DIV_ALGO == "nr2fma":
          $for N in range(SIMD_TILE):
            vr${ABC[N]} = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr${ABC[N]}, vd${ABC[N]}, vone), vr${ABC[N]}, vr${ABC[N]});

        $for N in range(SIMD_TILE):
          __m256 vf${ABC[N]} = _mm256_mul_ps(ve${ABC[N]}, vr${ABC[N]});

      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_andnot_ps(_mm256_cmp_ps(vz${ABC[N]}, vdenorm_cutoff, _CMP_LT_OS), vf${ABC[N]});

      $for N in range(SIMD_TILE):
        vf${ABC[N]} = _mm256_blendv_ps(_mm256_sub_ps(vone, vf${ABC[N]}), vf${ABC[N]}, vx${ABC[N]});

      _mm256_storeu_ps(y, vf${ABC[0]});
      $for N in range(1, SIMD_TILE):
        _mm256_storeu_ps(y + ${N * 8}, vf${ABC[N]});
      y += ${BATCH_TILE};
    }
  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
    const __m256 vx = _mm256_loadu_ps(x);
    x += 8;

    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

    __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
    vp = _mm256_fmadd_ps(vp, vt, vc3);
    vp = _mm256_fmadd_ps(vp, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);

    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);

    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      __m256 vr = _mm256_rcp_ps(vd);
      vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      $if DIV_ALGO == "nr2fma":
        vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      __m256 vf = _mm256_mul_ps(ve, vr);

    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    _mm256_storeu_ps(y, vf);
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 7 * sizeof(float));
    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx2_rr1_p5.mask_table[7] - n));

    const __m256 vx = _mm256_maskload_ps(x, vmask);

    const __m256 vz = _mm256_or_ps(vx, vsign_mask);

    __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
    vn = _mm256_sub_ps(vn, vmagic_bias);

    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);

    __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
    vp = _mm256_fmadd_ps(vp, vt, vc3);
    vp = _mm256_fmadd_ps(vp, vt, vc2);
    vp = _mm256_fmadd_ps(vp, vt, vc1);

    vt = _mm256_mul_ps(vt, vs);
    const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);

    const __m256 vd = _mm256_add_ps(ve, vone);
    $if DIV_ALGO == "div":
      __m256 vf = _mm256_div_ps(ve, vd);
    $else:
      __m256 vr = _mm256_rcp_ps(vd);
      vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      $if DIV_ALGO == "nr2fma":
        vr = _mm256_fmadd_ps(_mm256_fnmadd_ps(vr, vd, vone), vr, vr);
      __m256 vf = _mm256_mul_ps(ve, vr);

    vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
    vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);

    __m128 vf_lo = _mm256_castps256_ps128(vf);
    if (n & (4 * sizeof(float))) {
      _mm_storeu_ps(y, vf_lo);
      vf_lo = _mm256_extractf128_ps(vf, 1);
      y += 4;
    }
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vf_lo);
      vf_lo = _mm_movehl_ps(vf_lo, vf_lo);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vf_lo);
    }
  }
}