1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker// 3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker$assert ELEMENTS_TILE % 16 == 0 7*4bdc9457SAndroid Build Coastguard Worker$assert ELEMENTS_TILE >= 16 8*4bdc9457SAndroid Build Coastguard Worker$SIMD_TILE = ELEMENTS_TILE // 16 9*4bdc9457SAndroid Build Coastguard Worker$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10*4bdc9457SAndroid Build Coastguard Worker#include <assert.h> 11*4bdc9457SAndroid Build Coastguard Worker#include <math.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker#include <immintrin.h> 14*4bdc9457SAndroid Build Coastguard Worker 15*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/common.h> 16*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/intrinsics-polyfill.h> 17*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/raddextexp.h> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Workervoid xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_x${ELEMENTS_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}( 21*4bdc9457SAndroid Build Coastguard Worker size_t elements, 22*4bdc9457SAndroid Build Coastguard Worker const float* x, 23*4bdc9457SAndroid Build Coastguard Worker float* sum) 24*4bdc9457SAndroid Build Coastguard Worker{ 25*4bdc9457SAndroid Build Coastguard Worker assert(elements % sizeof(float) == 0); 26*4bdc9457SAndroid Build Coastguard Worker 27*4bdc9457SAndroid Build Coastguard Worker const __m512 vlog2e = _mm512_set1_ps(0x1.715476p+0f); 28*4bdc9457SAndroid Build Coastguard Worker const __m512 vminus_ln2_hi = _mm512_set1_ps(-0x1.62E43p-1f); 29*4bdc9457SAndroid Build Coastguard Worker const __m512 vminus_ln2_lo = _mm512_set1_ps(0x1.05C61p-29f); 30*4bdc9457SAndroid Build Coastguard Worker 31*4bdc9457SAndroid Build Coastguard Worker const __m512 vc0 = _mm512_set1_ps(1.0f); 32*4bdc9457SAndroid Build Coastguard Worker const __m512 vc1 = _mm512_set1_ps(0x1.FFFFF6p-1f); 33*4bdc9457SAndroid Build Coastguard Worker const __m512 vc2 = _mm512_set1_ps(0x1.FFFDC6p-2f); 34*4bdc9457SAndroid Build Coastguard Worker const __m512 vc3 = _mm512_set1_ps(0x1.555A80p-3f); 35*4bdc9457SAndroid Build Coastguard Worker const __m512 vc4 = _mm512_set1_ps(0x1.573A1Ap-5f); 36*4bdc9457SAndroid Build Coastguard Worker const __m512 vc5 = _mm512_set1_ps(0x1.0F9F9Cp-7f); 37*4bdc9457SAndroid Build Coastguard Worker 38*4bdc9457SAndroid Build Coastguard Worker const __m512 vminus_inf = _mm512_set1_ps(-INFINITY); 39*4bdc9457SAndroid Build Coastguard Worker 40*4bdc9457SAndroid Build Coastguard Worker $for K in range(ACCUMULATORS): 41*4bdc9457SAndroid Build Coastguard Worker __m512 vaccv${K} = _mm512_setzero_ps(); 42*4bdc9457SAndroid Build Coastguard Worker $for K in range(ACCUMULATORS): 43*4bdc9457SAndroid Build Coastguard Worker __m512 vacce${K} = vminus_inf; 44*4bdc9457SAndroid Build Coastguard Worker for (; elements >= ${ELEMENTS_TILE} * sizeof(float); elements -= ${ELEMENTS_TILE} * sizeof(float)) { 45*4bdc9457SAndroid Build Coastguard Worker // Load ${ELEMENTS_TILE} (${SIMD_TILE}x16) inputs at a time. 46*4bdc9457SAndroid Build Coastguard Worker const __m512 vx0 = _mm512_loadu_ps(x); 47*4bdc9457SAndroid Build Coastguard Worker $for N in range(1, SIMD_TILE): 48*4bdc9457SAndroid Build Coastguard Worker const __m512 vx${N} = _mm512_loadu_ps(x + ${N * 16}); 49*4bdc9457SAndroid Build Coastguard Worker x += ${ELEMENTS_TILE}; 50*4bdc9457SAndroid Build Coastguard Worker 51*4bdc9457SAndroid Build Coastguard Worker // Compute reduced argument elements := round(x / log(2)). 52*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 53*4bdc9457SAndroid Build Coastguard Worker const __m512 vn${N} = _mm512_roundscale_ps(_mm512_mul_ps(vx${N}, vlog2e), 0); 54*4bdc9457SAndroid Build Coastguard Worker 55*4bdc9457SAndroid Build Coastguard Worker // Compute reduced argument t := x - elements * log(2). 56*4bdc9457SAndroid Build Coastguard Worker // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. 57*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 58*4bdc9457SAndroid Build Coastguard Worker __m512 vt${N} = _mm512_fmadd_ps(vn${N}, vminus_ln2_hi, vx${N}); 59*4bdc9457SAndroid Build Coastguard Worker 60*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 61*4bdc9457SAndroid Build Coastguard Worker vt${N} = _mm512_fmadd_ps(vn${N}, vminus_ln2_lo, vt${N}); 62*4bdc9457SAndroid Build Coastguard Worker 63*4bdc9457SAndroid Build Coastguard Worker // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2]. 64*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 65*4bdc9457SAndroid Build Coastguard Worker __m512 vp${N} = _mm512_fmadd_ps(vc5, vt${N}, vc4); 66*4bdc9457SAndroid Build Coastguard Worker 67*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 68*4bdc9457SAndroid Build Coastguard Worker vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc3); 69*4bdc9457SAndroid Build Coastguard Worker 70*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 71*4bdc9457SAndroid Build Coastguard Worker vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc2); 72*4bdc9457SAndroid Build Coastguard Worker 73*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 74*4bdc9457SAndroid Build Coastguard Worker vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc1); 75*4bdc9457SAndroid Build Coastguard Worker 76*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 77*4bdc9457SAndroid Build Coastguard Worker vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc0); 78*4bdc9457SAndroid Build Coastguard Worker 79*4bdc9457SAndroid Build Coastguard Worker // Accumulate "extended" floating-point numbers in ("mantissa", "exponent") representation where 80*4bdc9457SAndroid Build Coastguard Worker // - vnX is "exponent" 81*4bdc9457SAndroid Build Coastguard Worker // - vpX is "mantissa" 82*4bdc9457SAndroid Build Coastguard Worker // 83*4bdc9457SAndroid Build Coastguard Worker // exp2(ae) * av + exp2(be) * bv = 84*4bdc9457SAndroid Build Coastguard Worker // = exp2(max(ae, be)) * exp2(ae - max(ae, be)) * av + exp2(max(ae, be)) * exp2(be - max(ae, be)) * bv 85*4bdc9457SAndroid Build Coastguard Worker // = exp2(max_e) * (exp2(ae - max_e) * av + exp2(be - max_e) * bv) 86*4bdc9457SAndroid Build Coastguard Worker // = exp2(max_e) * (exp2(delta_ae) * av + exp2(delta_be) * bv) 87*4bdc9457SAndroid Build Coastguard Worker // 88*4bdc9457SAndroid Build Coastguard Worker // For computational efficiency we add three "extended" floating-point numbers at a time. 89*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 90*4bdc9457SAndroid Build Coastguard Worker $if N < ACCUMULATORS: 91*4bdc9457SAndroid Build Coastguard Worker __m512 vmax_e${N} = _mm512_max_ps(vacce${N}, vn${N}); 92*4bdc9457SAndroid Build Coastguard Worker $else: 93*4bdc9457SAndroid Build Coastguard Worker vmax_e${N % ACCUMULATORS} = _mm512_max_ps(vmax_e${N % ACCUMULATORS}, vn${N}); 94*4bdc9457SAndroid Build Coastguard Worker 95*4bdc9457SAndroid Build Coastguard Worker $for K in range(ACCUMULATORS): 96*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_acce${K} = _mm512_sub_ps(vacce${K}, vmax_e${K}); 97*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 98*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_e${N} = _mm512_sub_ps(vn${N}, vmax_e${N % ACCUMULATORS}); 99*4bdc9457SAndroid Build Coastguard Worker 100*4bdc9457SAndroid Build Coastguard Worker // Update accumulated "mantissa" and "exponent" values 101*4bdc9457SAndroid Build Coastguard Worker $for K in range(ACCUMULATORS): 102*4bdc9457SAndroid Build Coastguard Worker vaccv${K} = _mm512_scalef_ps(vaccv${K}, vdelta_acce${K}); 103*4bdc9457SAndroid Build Coastguard Worker $for N in range(SIMD_TILE): 104*4bdc9457SAndroid Build Coastguard Worker vaccv${N % ACCUMULATORS} = _mm512_add_ps(vaccv${N % ACCUMULATORS}, _mm512_scalef_ps(vp${N}, vdelta_e${N})); 105*4bdc9457SAndroid Build Coastguard Worker 106*4bdc9457SAndroid Build Coastguard Worker $for K in range(ACCUMULATORS): 107*4bdc9457SAndroid Build Coastguard Worker vacce${K} = vmax_e${K}; 108*4bdc9457SAndroid Build Coastguard Worker } 109*4bdc9457SAndroid Build Coastguard Worker 110*4bdc9457SAndroid Build Coastguard Worker // Reduce partial sums of "extended" floating-point numbers into a single "extended" SIMD vector of sums. 111*4bdc9457SAndroid Build Coastguard Worker $if ACCUMULATORS > 1: 112*4bdc9457SAndroid Build Coastguard Worker $for A in range(0, ACCUMULATORS, 2): 113*4bdc9457SAndroid Build Coastguard Worker $if A + 1 < ACCUMULATORS: 114*4bdc9457SAndroid Build Coastguard Worker const __m512 vmax_acce${ABC[A:A+2]} = _mm512_max_ps(vacce${A}, vacce${A+1}); 115*4bdc9457SAndroid Build Coastguard Worker $else: 116*4bdc9457SAndroid Build Coastguard Worker const __m512 vmax_acce${ABC[A]} = vacce${A}; 117*4bdc9457SAndroid Build Coastguard Worker $ACC_SLICE = 2 118*4bdc9457SAndroid Build Coastguard Worker $while ACC_SLICE < ACCUMULATORS: 119*4bdc9457SAndroid Build Coastguard Worker $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): 120*4bdc9457SAndroid Build Coastguard Worker $if A + ACC_SLICE < ACCUMULATORS: 121*4bdc9457SAndroid Build Coastguard Worker const __m512 vmax_acce${ABC[A:min(A+ACC_SLICE*2, ACCUMULATORS)]} = _mm512_max_ps(vmax_acce${ABC[A:A+ACC_SLICE]}, vmax_acce${ABC[A+ACC_SLICE:min(ACCUMULATORS,A+ACC_SLICE*2)]}); 122*4bdc9457SAndroid Build Coastguard Worker $ACC_SLICE *= 2 123*4bdc9457SAndroid Build Coastguard Worker 124*4bdc9457SAndroid Build Coastguard Worker $for K in range(ACCUMULATORS): 125*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_acce${K} = _mm512_sub_ps(vacce${K}, vmax_acce${ABC[0:ACCUMULATORS]}); 126*4bdc9457SAndroid Build Coastguard Worker 127*4bdc9457SAndroid Build Coastguard Worker __m512 vaccv = _mm512_scalef_ps(vaccv0, vdelta_acce0); 128*4bdc9457SAndroid Build Coastguard Worker $for K in range(1, ACCUMULATORS): 129*4bdc9457SAndroid Build Coastguard Worker vaccv = _mm512_add_ps(vaccv, _mm512_scalef_ps(vaccv${K}, vdelta_acce${K})); 130*4bdc9457SAndroid Build Coastguard Worker __m512 vacce = vmax_acce${ABC[0:ACCUMULATORS]}; 131*4bdc9457SAndroid Build Coastguard Worker $else: 132*4bdc9457SAndroid Build Coastguard Worker __m512 vaccv = vaccv0; 133*4bdc9457SAndroid Build Coastguard Worker __m512 vacce = vacce0; 134*4bdc9457SAndroid Build Coastguard Worker 135*4bdc9457SAndroid Build Coastguard Worker for (; elements >= 16 * sizeof(float); elements -= 16 * sizeof(float)) { 136*4bdc9457SAndroid Build Coastguard Worker // Load 16 inputs at a time. 137*4bdc9457SAndroid Build Coastguard Worker const __m512 vx = _mm512_loadu_ps(x); 138*4bdc9457SAndroid Build Coastguard Worker x += 16; 139*4bdc9457SAndroid Build Coastguard Worker 140*4bdc9457SAndroid Build Coastguard Worker // Compute reduced argument elements := round(x / log(2)). 141*4bdc9457SAndroid Build Coastguard Worker const __m512 vn = _mm512_roundscale_ps(_mm512_mul_ps(vx, vlog2e), 0); 142*4bdc9457SAndroid Build Coastguard Worker 143*4bdc9457SAndroid Build Coastguard Worker // Compute reduced argument t := x - elements * log(2). 144*4bdc9457SAndroid Build Coastguard Worker // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. 145*4bdc9457SAndroid Build Coastguard Worker __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2_hi, vx); 146*4bdc9457SAndroid Build Coastguard Worker vt = _mm512_fmadd_ps(vn, vminus_ln2_lo, vt); 147*4bdc9457SAndroid Build Coastguard Worker 148*4bdc9457SAndroid Build Coastguard Worker // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2]. 149*4bdc9457SAndroid Build Coastguard Worker __m512 vp = _mm512_fmadd_ps(vc5, vt, vc4); 150*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc3); 151*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc2); 152*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc1); 153*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc0); 154*4bdc9457SAndroid Build Coastguard Worker 155*4bdc9457SAndroid Build Coastguard Worker // Accumulate "extended" floating-point numbers in ("mantissa", "exponent") representation. 156*4bdc9457SAndroid Build Coastguard Worker const __m512 vmax_e = _mm512_max_ps(vacce, vn); 157*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_acce = _mm512_sub_ps(vacce, vmax_e); 158*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_e = _mm512_sub_ps(vn, vmax_e); 159*4bdc9457SAndroid Build Coastguard Worker vaccv = _mm512_scalef_ps(vaccv, vdelta_acce); 160*4bdc9457SAndroid Build Coastguard Worker vaccv = _mm512_add_ps(vaccv, _mm512_scalef_ps(vp, vdelta_e)); 161*4bdc9457SAndroid Build Coastguard Worker 162*4bdc9457SAndroid Build Coastguard Worker vacce = vmax_e; 163*4bdc9457SAndroid Build Coastguard Worker } 164*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(elements != 0) { 165*4bdc9457SAndroid Build Coastguard Worker // Prepare mask for valid 32-bit elements (depends on elements). 166*4bdc9457SAndroid Build Coastguard Worker elements >>= 2 /* log2(sizeof(float)) */; 167*4bdc9457SAndroid Build Coastguard Worker const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << elements) - UINT32_C(1))); 168*4bdc9457SAndroid Build Coastguard Worker 169*4bdc9457SAndroid Build Coastguard Worker // Load up to 15 inputs at a time. 170*4bdc9457SAndroid Build Coastguard Worker const __m512 vx = _mm512_maskz_loadu_ps(vmask, x); 171*4bdc9457SAndroid Build Coastguard Worker 172*4bdc9457SAndroid Build Coastguard Worker // Compute reduced argument elements := round(x / log(2)). 173*4bdc9457SAndroid Build Coastguard Worker const __m512 vn = _mm512_roundscale_ps(_mm512_mul_ps(vx, vlog2e), 0); 174*4bdc9457SAndroid Build Coastguard Worker 175*4bdc9457SAndroid Build Coastguard Worker // Compute reduced argument t := x - elements * log(2). 176*4bdc9457SAndroid Build Coastguard Worker // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy. 177*4bdc9457SAndroid Build Coastguard Worker __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2_hi, vx); 178*4bdc9457SAndroid Build Coastguard Worker vt = _mm512_fmadd_ps(vn, vminus_ln2_lo, vt); 179*4bdc9457SAndroid Build Coastguard Worker 180*4bdc9457SAndroid Build Coastguard Worker // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2]. 181*4bdc9457SAndroid Build Coastguard Worker __m512 vp = _mm512_fmadd_ps(vc5, vt, vc4); 182*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc3); 183*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc2); 184*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc1); 185*4bdc9457SAndroid Build Coastguard Worker vp = _mm512_fmadd_ps(vp, vt, vc0); 186*4bdc9457SAndroid Build Coastguard Worker 187*4bdc9457SAndroid Build Coastguard Worker // Accumulate "extended" floating-point numbers in ("mantissa", "exponent") representation. 188*4bdc9457SAndroid Build Coastguard Worker const __m512 vmax_e = _mm512_mask_max_ps(vacce, vmask, vacce, vn); 189*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_acce = _mm512_sub_ps(vacce, vmax_e); 190*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_e = _mm512_sub_ps(vn, vmax_e); 191*4bdc9457SAndroid Build Coastguard Worker vaccv = _mm512_mask_scalef_ps(vaccv, vmask, vaccv, vdelta_acce); 192*4bdc9457SAndroid Build Coastguard Worker vaccv = _mm512_mask_add_ps(vaccv, vmask, vaccv, _mm512_maskz_scalef_ps(vmask, vp, vdelta_e)); 193*4bdc9457SAndroid Build Coastguard Worker vacce = vmax_e; 194*4bdc9457SAndroid Build Coastguard Worker } 195*4bdc9457SAndroid Build Coastguard Worker 196*4bdc9457SAndroid Build Coastguard Worker // Reduce partial sums of "extended" floating-point numbers into a single "extended" floating-point sum. 197*4bdc9457SAndroid Build Coastguard Worker const float vmax_acce = _mm512_reduce_max_ps(vacce); 198*4bdc9457SAndroid Build Coastguard Worker const __m512 vdelta_acce = _mm512_sub_ps(vacce, _mm512_set1_ps(vmax_acce)); 199*4bdc9457SAndroid Build Coastguard Worker 200*4bdc9457SAndroid Build Coastguard Worker sum[0] = _mm512_reduce_add_ps(_mm512_scalef_ps(vaccv, vdelta_acce)); 201*4bdc9457SAndroid Build Coastguard Worker sum[1] = vmax_acce; 202*4bdc9457SAndroid Build Coastguard Worker 203*4bdc9457SAndroid Build Coastguard Worker _mm256_zeroupper(); 204*4bdc9457SAndroid Build Coastguard Worker} 205