1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2022 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker// 3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker$assert BATCH_TILE % 8 == 0 7*4bdc9457SAndroid Build Coastguard Worker$assert BATCH_TILE >= 8 8*4bdc9457SAndroid Build Coastguard Worker$assert DIV_ALGO in ["div", "rcp"] 9*4bdc9457SAndroid Build Coastguard Worker$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10*4bdc9457SAndroid Build Coastguard Worker$SIMD_TILE = BATCH_TILE // 8 11*4bdc9457SAndroid Build Coastguard Worker#include <assert.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker#include <immintrin.h> 14*4bdc9457SAndroid Build Coastguard Worker 15*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/common.h> 16*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/intrinsics-polyfill.h> 17*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/vunary.h> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Workervoid xnn_f16_vsigmoid_ukernel__avx2_rr1_p2_${DIV_ALGO}_x${BATCH_TILE}( 21*4bdc9457SAndroid Build Coastguard Worker size_t batch, 22*4bdc9457SAndroid Build Coastguard Worker const void* input, 23*4bdc9457SAndroid Build Coastguard Worker void* output, 24*4bdc9457SAndroid Build Coastguard Worker const union xnn_f16_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)]) 25*4bdc9457SAndroid Build Coastguard Worker{ 26*4bdc9457SAndroid Build Coastguard Worker assert(batch % sizeof(uint16_t) == 0); 27*4bdc9457SAndroid Build Coastguard Worker 28*4bdc9457SAndroid Build Coastguard Worker const __m256 vsign_mask = _mm256_load_ps(params->avx2_rr1_p2.sign_mask); 29*4bdc9457SAndroid Build Coastguard Worker const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p2.magic_bias); 30*4bdc9457SAndroid Build Coastguard Worker const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p2.log2e); 31*4bdc9457SAndroid Build Coastguard Worker const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p2.minus_ln2); 32*4bdc9457SAndroid Build Coastguard Worker const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p2.c2); 33*4bdc9457SAndroid Build Coastguard Worker const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p2.c1); 34*4bdc9457SAndroid Build Coastguard Worker const __m256 vone = _mm256_load_ps(params->avx2_rr1_p2.one); 35*4bdc9457SAndroid Build Coastguard Worker const __m256 vdenorm_cutoff = _mm256_load_ps(params->avx2_rr1_p2.denorm_cutoff); 36*4bdc9457SAndroid Build Coastguard Worker 37*4bdc9457SAndroid Build Coastguard Worker const uint16_t* i = (const uint16_t*) input; 38*4bdc9457SAndroid Build Coastguard Worker uint16_t* o = (uint16_t*) output; 39*4bdc9457SAndroid Build Coastguard Worker $if BATCH_TILE > 8: 40*4bdc9457SAndroid Build Coastguard Worker for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) { 41*4bdc9457SAndroid Build Coastguard Worker const __m256 vx${ABC[0]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); 42*4bdc9457SAndroid Build Coastguard Worker $for N in range(1, SIMD_TILE): 43*4bdc9457SAndroid Build Coastguard Worker const __m256 vx${ABC[N]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + ${N * 8}))); 44*4bdc9457SAndroid Build Coastguard Worker i += ${BATCH_TILE}; 45*4bdc9457SAndroid Build Coastguard Worker 46*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 47*4bdc9457SAndroid Build Coastguard Worker const __m256 vz${ABC[N]} = _mm256_or_ps(vx${ABC[N]}, vsign_mask); 48*4bdc9457SAndroid Build Coastguard Worker 49*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 50*4bdc9457SAndroid Build Coastguard Worker __m256 vn${ABC[N]} = _mm256_fmadd_ps(vz${ABC[N]}, vlog2e, vmagic_bias); 51*4bdc9457SAndroid Build Coastguard Worker 52*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 53*4bdc9457SAndroid Build Coastguard Worker const __m256 vs${ABC[N]} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${ABC[N]}), 23)); 54*4bdc9457SAndroid Build Coastguard Worker 55*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 56*4bdc9457SAndroid Build Coastguard Worker vn${ABC[N]} = _mm256_sub_ps(vn${ABC[N]}, vmagic_bias); 57*4bdc9457SAndroid Build Coastguard Worker 58*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 59*4bdc9457SAndroid Build Coastguard Worker __m256 vt${ABC[N]} = _mm256_fmadd_ps(vn${ABC[N]}, vminus_ln2, vz${ABC[N]}); 60*4bdc9457SAndroid Build Coastguard Worker 61*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 62*4bdc9457SAndroid Build Coastguard Worker const __m256 vp${ABC[N]} = _mm256_fmadd_ps(vc2, vt${ABC[N]}, vc1); 63*4bdc9457SAndroid Build Coastguard Worker 64*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 65*4bdc9457SAndroid Build Coastguard Worker vt${ABC[N]} = _mm256_mul_ps(vt${ABC[N]}, vs${ABC[N]}); 66*4bdc9457SAndroid Build Coastguard Worker 67*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 68*4bdc9457SAndroid Build Coastguard Worker const __m256 ve${ABC[N]} = _mm256_fmadd_ps(vt${ABC[N]}, vp${ABC[N]}, vs${ABC[N]}); 69*4bdc9457SAndroid Build Coastguard Worker 70*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 71*4bdc9457SAndroid Build Coastguard Worker const __m256 vd${ABC[N]} = _mm256_add_ps(ve${ABC[N]}, vone); 72*4bdc9457SAndroid Build Coastguard Worker 73*4bdc9457SAndroid Build Coastguard Worker $if DIV_ALGO == "div": 74*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 75*4bdc9457SAndroid Build Coastguard Worker __m256 vf${ABC[N]} = _mm256_div_ps(ve${ABC[N]}, vd${ABC[N]}); 76*4bdc9457SAndroid Build Coastguard Worker $else: 77*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 78*4bdc9457SAndroid Build Coastguard Worker const __m256 vr${ABC[N]} = _mm256_rcp_ps(vd${ABC[N]}); 79*4bdc9457SAndroid Build Coastguard Worker 80*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 81*4bdc9457SAndroid Build Coastguard Worker __m256 vf${ABC[N]} = _mm256_mul_ps(ve${ABC[N]}, vr${ABC[N]}); 82*4bdc9457SAndroid Build Coastguard Worker 83*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 84*4bdc9457SAndroid Build Coastguard Worker vf${ABC[N]} = _mm256_andnot_ps(_mm256_cmp_ps(vz${ABC[N]}, vdenorm_cutoff, _CMP_LT_OS), vf${ABC[N]}); 85*4bdc9457SAndroid Build Coastguard Worker 86*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 87*4bdc9457SAndroid Build Coastguard Worker vf${ABC[N]} = _mm256_blendv_ps(_mm256_sub_ps(vone, vf${ABC[N]}), vf${ABC[N]}, vx${ABC[N]}); 88*4bdc9457SAndroid Build Coastguard Worker 89*4bdc9457SAndroid Build Coastguard Worker _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf${ABC[0]}, _MM_FROUND_NO_EXC)); 90*4bdc9457SAndroid Build Coastguard Worker $for N in range(1, SIMD_TILE): 91*4bdc9457SAndroid Build Coastguard Worker _mm_storeu_si128((__m128i*) (o + ${N * 8}), _mm256_cvtps_ph(vf${ABC[N]}, _MM_FROUND_NO_EXC)); 92*4bdc9457SAndroid Build Coastguard Worker o += ${BATCH_TILE}; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) { 95*4bdc9457SAndroid Build Coastguard Worker const __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); 96*4bdc9457SAndroid Build Coastguard Worker i += 8; 97*4bdc9457SAndroid Build Coastguard Worker 98*4bdc9457SAndroid Build Coastguard Worker const __m256 vz = _mm256_or_ps(vx, vsign_mask); 99*4bdc9457SAndroid Build Coastguard Worker 100*4bdc9457SAndroid Build Coastguard Worker __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias); 101*4bdc9457SAndroid Build Coastguard Worker const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23)); 102*4bdc9457SAndroid Build Coastguard Worker vn = _mm256_sub_ps(vn, vmagic_bias); 103*4bdc9457SAndroid Build Coastguard Worker 104*4bdc9457SAndroid Build Coastguard Worker __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz); 105*4bdc9457SAndroid Build Coastguard Worker 106*4bdc9457SAndroid Build Coastguard Worker const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1); 107*4bdc9457SAndroid Build Coastguard Worker vt = _mm256_mul_ps(vt, vs); 108*4bdc9457SAndroid Build Coastguard Worker const __m256 ve = _mm256_fmadd_ps(vt, vp, vs); 109*4bdc9457SAndroid Build Coastguard Worker 110*4bdc9457SAndroid Build Coastguard Worker const __m256 vd = _mm256_add_ps(ve, vone); 111*4bdc9457SAndroid Build Coastguard Worker $if DIV_ALGO == "div": 112*4bdc9457SAndroid Build Coastguard Worker __m256 vf = _mm256_div_ps(ve, vd); 113*4bdc9457SAndroid Build Coastguard Worker $else: 114*4bdc9457SAndroid Build Coastguard Worker const __m256 vr = _mm256_rcp_ps(vd); 115*4bdc9457SAndroid Build Coastguard Worker __m256 vf = _mm256_mul_ps(ve, vr); 116*4bdc9457SAndroid Build Coastguard Worker 117*4bdc9457SAndroid Build Coastguard Worker vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf); 118*4bdc9457SAndroid Build Coastguard Worker vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx); 119*4bdc9457SAndroid Build Coastguard Worker 120*4bdc9457SAndroid Build Coastguard Worker _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC)); 121*4bdc9457SAndroid Build Coastguard Worker o += 8; 122*4bdc9457SAndroid Build Coastguard Worker } 123*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(batch != 0) { 124*4bdc9457SAndroid Build Coastguard Worker assert(batch >= 1 * sizeof(uint16_t)); 125*4bdc9457SAndroid Build Coastguard Worker assert(batch <= 7 * sizeof(uint16_t)); 126*4bdc9457SAndroid Build Coastguard Worker const __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i)); 127*4bdc9457SAndroid Build Coastguard Worker 128*4bdc9457SAndroid Build Coastguard Worker const __m256 vz = _mm256_or_ps(vx, vsign_mask); 129*4bdc9457SAndroid Build Coastguard Worker 130*4bdc9457SAndroid Build Coastguard Worker __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias); 131*4bdc9457SAndroid Build Coastguard Worker const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23)); 132*4bdc9457SAndroid Build Coastguard Worker vn = _mm256_sub_ps(vn, vmagic_bias); 133*4bdc9457SAndroid Build Coastguard Worker 134*4bdc9457SAndroid Build Coastguard Worker __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz); 135*4bdc9457SAndroid Build Coastguard Worker 136*4bdc9457SAndroid Build Coastguard Worker const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1); 137*4bdc9457SAndroid Build Coastguard Worker vt = _mm256_mul_ps(vt, vs); 138*4bdc9457SAndroid Build Coastguard Worker const __m256 ve = _mm256_fmadd_ps(vt, vp, vs); 139*4bdc9457SAndroid Build Coastguard Worker 140*4bdc9457SAndroid Build Coastguard Worker const __m256 vd = _mm256_add_ps(ve, vone); 141*4bdc9457SAndroid Build Coastguard Worker $if DIV_ALGO == "div": 142*4bdc9457SAndroid Build Coastguard Worker __m256 vf = _mm256_div_ps(ve, vd); 143*4bdc9457SAndroid Build Coastguard Worker $else: 144*4bdc9457SAndroid Build Coastguard Worker const __m256 vr = _mm256_rcp_ps(vd); 145*4bdc9457SAndroid Build Coastguard Worker __m256 vf = _mm256_mul_ps(ve, vr); 146*4bdc9457SAndroid Build Coastguard Worker 147*4bdc9457SAndroid Build Coastguard Worker vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf); 148*4bdc9457SAndroid Build Coastguard Worker vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx); 149*4bdc9457SAndroid Build Coastguard Worker 150*4bdc9457SAndroid Build Coastguard Worker __m128i vh = _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC); 151*4bdc9457SAndroid Build Coastguard Worker if (batch & (4 * sizeof(uint16_t))) { 152*4bdc9457SAndroid Build Coastguard Worker _mm_storel_epi64((__m128i*) o, vh); 153*4bdc9457SAndroid Build Coastguard Worker vh = _mm_unpackhi_epi64(vh, vh); 154*4bdc9457SAndroid Build Coastguard Worker o += 4; 155*4bdc9457SAndroid Build Coastguard Worker } 156*4bdc9457SAndroid Build Coastguard Worker if (batch & (2 * sizeof(uint16_t))) { 157*4bdc9457SAndroid Build Coastguard Worker _mm_storeu_si32(o, vh); 158*4bdc9457SAndroid Build Coastguard Worker vh = _mm_srli_epi64(vh, 32); 159*4bdc9457SAndroid Build Coastguard Worker o += 2; 160*4bdc9457SAndroid Build Coastguard Worker } 161*4bdc9457SAndroid Build Coastguard Worker if (batch & (1 * sizeof(uint16_t))) { 162*4bdc9457SAndroid Build Coastguard Worker *o = (uint16_t) _mm_extract_epi16(vh, 0); 163*4bdc9457SAndroid Build Coastguard Worker } 164*4bdc9457SAndroid Build Coastguard Worker } 165*4bdc9457SAndroid Build Coastguard Worker} 166