// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ELEMENTS_TILE % 4 == 0
$assert ELEMENTS_TILE >= 4
$SIMD_TILE = ELEMENTS_TILE // 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$VMULADDQ_F32 = "vfmaq_f32" if FMA else "vmlaq_f32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/raddstoreexpminusmax.h>


extern XNN_INTERNAL const float xnn_table_exp2_k_over_64[64];

$PARAMS_STRUCT = "neonfma_rr1_lut64_p2" if FMA else "neon_rr2_lut64_p2"
void xnn_f32_raddstoreexpminusmax_ukernel__${"neonfma" if FMA else "neon"}_rr${1 if FMA else 2}_lut64_p2_x${ELEMENTS_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t elements,
    const float* input,
    const float* max,
    float* output,
    float* sum,
    const union xnn_f32_expminus_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(elements % sizeof(float) == 0);

  // exp(x - max) is approximated as 2**e * 2**(k/64) * (1 + t + c2 * t**2), using a 64-entry table for 2**(k/64).
  const float32x4_t vi_max = vld1q_dup_f32(max);
  const float32x4_t vlog2e = vld1q_dup_f32(&params->${PARAMS_STRUCT}.log2e);
  const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
  const int32x4_t vindex_mask = vmovq_n_s32(INT32_C(0x3F));
  $if FMA:
    const float32x4_t vminus_ln2 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.minus_ln2);
  $else:
    const float32x4_t vminus_ln2_hi = vld1q_dup_f32(&params->${PARAMS_STRUCT}.minus_ln2_hi);
    const float32x4_t vminus_ln2_lo = vld1q_dup_f32(&params->${PARAMS_STRUCT}.minus_ln2_lo);
  const float32x4_t vc2 = vld1q_dup_f32(&params->${PARAMS_STRUCT}.c2);
  const float32x4_t vdenorm_cutoff = vld1q_dup_f32(&params->${PARAMS_STRUCT}.denorm_cutoff);

  $if ELEMENTS_TILE > 4:
    $for K in range(ACCUMULATORS):
      float32x4_t vacc${K} = vmovq_n_f32(0.0f);
    for (; elements >= ${ELEMENTS_TILE} * sizeof(float); elements -= ${ELEMENTS_TILE} * sizeof(float)) {
      $for N in range(0, ELEMENTS_TILE, 4):
        const float32x4_t vi${ABC[N:N+4]} = vld1q_f32(input); input += 4;

      $for N in range(0, ELEMENTS_TILE, 4):
        const float32x4_t vx${ABC[N:N+4]} = vsubq_f32(vi${ABC[N:N+4]}, vi_max);

      // n := round(x * 64 / ln(2)), computed with the magic-bias trick; the biased result keeps the table index
      // in its low 6 bits and the integer exponent in the bits above.
      $for N in range(0, ELEMENTS_TILE, 4):
        float32x4_t vn${ABC[N:N+4]} = ${VMULADDQ_F32}(vmagic_bias, vx${ABC[N:N+4]}, vlog2e);

      // Extract the exponent bits: clear the 6 index bits and shift the rest into the float32 exponent field.
      $for N in range(0, ELEMENTS_TILE, 4):
        const int32x4_t ve${ABC[N:N+4]} = vshlq_n_s32(vbicq_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), vmovq_n_s32(INT32_C(0x3F))), 17);

      // Use the low 6 bits of each n as an index into the 2**(k/64) table.
      $for N in range(0, ELEMENTS_TILE, 4):
        const uint64x2_t vidx${ABC[N:N+4]} = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn${ABC[N:N+4]}), vindex_mask));
        const uint64_t vidx${ABC[N:N+2]} = vgetq_lane_u64(vidx${ABC[N:N+4]}, 0);
        const uint64_t vidx${ABC[N+2:N+4]} = vgetq_lane_u64(vidx${ABC[N:N+4]}, 1);

      $for N in range(0, ELEMENTS_TILE, 4):
        float32x2_t vl${ABC[N:N+2]} = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx${ABC[N:N+2]}]);
        float32x2_t vl${ABC[N+2:N+4]} = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx${ABC[N+2:N+4]}]);

      $for N in range(0, ELEMENTS_TILE, 4):
        vl${ABC[N:N+2]} = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx${ABC[N:N+2]} >> 32)], vl${ABC[N:N+2]}, 1);
        vl${ABC[N+2:N+4]} = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx${ABC[N+2:N+4]} >> 32)], vl${ABC[N+2:N+4]}, 1);
        const float32x4_t vl${ABC[N:N+4]} = vcombine_f32(vl${ABC[N:N+2]}, vl${ABC[N+2:N+4]});

      // Reconstruct s := 2**(n/64) by adding the exponent bits to the bit pattern of the table entry.
      $for N in range(0, ELEMENTS_TILE, 4):
        const float32x4_t vs${ABC[N:N+4]} = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl${ABC[N:N+4]}), ve${ABC[N:N+4]}));

      // Subtract the magic bias to recover n as a float.
      $for N in range(0, ELEMENTS_TILE, 4):
        vn${ABC[N:N+4]} = vsubq_f32(vn${ABC[N:N+4]}, vmagic_bias);

      // Compute the reduced argument t := x - n * ln(2) / 64.
      $if FMA:
        $for N in range(0, ELEMENTS_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vx${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2);
      $else:
        $for N in range(0, ELEMENTS_TILE, 4):
          float32x4_t vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vx${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2_hi);

        $for N in range(0, ELEMENTS_TILE, 4):
          vt${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vn${ABC[N:N+4]}, vminus_ln2_lo);

      // Evaluate the degree-2 polynomial correction: p := t + c2 * t**2.
      $for N in range(0, ELEMENTS_TILE, 4):
        float32x4_t vp${ABC[N:N+4]} = vmulq_f32(vt${ABC[N:N+4]}, vc2);

      $for N in range(0, ELEMENTS_TILE, 4):
        vp${ABC[N:N+4]} = ${VMULADDQ_F32}(vt${ABC[N:N+4]}, vt${ABC[N:N+4]}, vp${ABC[N:N+4]});

      // f := s + s * p ~ exp(x - max).
      $for N in range(0, ELEMENTS_TILE, 4):
        float32x4_t vf${ABC[N:N+4]} = ${VMULADDQ_F32}(vs${ABC[N:N+4]}, vs${ABC[N:N+4]}, vp${ABC[N:N+4]});

      // Flush to zero the lanes where x - max is below the denormal cutoff.
      $for N in range(0, ELEMENTS_TILE, 4):
        vf${ABC[N:N+4]} = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf${ABC[N:N+4]}), vcltq_f32(vx${ABC[N:N+4]}, vdenorm_cutoff)));

      // Store the results and add them to the running sums.
      $for N in range(0, ELEMENTS_TILE, 4):
        vst1q_f32(output, vf${ABC[N:N+4]}); output += 4;

      $for N in range(0, ELEMENTS_TILE, 4):
        vacc${N % ACCUMULATORS} = vaddq_f32(vacc${N % ACCUMULATORS}, vf${ABC[N:N+4]});
    }
    $if ACCUMULATORS > 1:
      $ACC_SLICE = 1
      $while ACC_SLICE < ACCUMULATORS:
        $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
          $if A + ACC_SLICE < ACCUMULATORS:
            vacc${A} = vaddq_f32(vacc${A}, vacc${A + ACC_SLICE});
        $ACC_SLICE *= 2

    float32x4_t vacc = vacc0;
  $else:
    float32x4_t vacc = vmovq_n_f32(0.0f);
  // Process the remaining full 4-element vectors.
  for (; elements >= 4 * sizeof(float); elements -= 4 * sizeof(float)) {
    const float32x4_t vi = vld1q_f32(input); input += 4;

    const float32x4_t vx = vsubq_f32(vi, vi_max);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vx, vlog2e);

    const int32x4_t ve = vshlq_n_s32(vbicq_s32(vreinterpretq_s32_f32(vn), vmovq_n_s32(INT32_C(0x3F))), 17);

    const uint64x2_t vidx = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn), vindex_mask));
    const uint64_t vidx_lo = vgetq_lane_u64(vidx, 0);
    const uint64_t vidx_hi = vgetq_lane_u64(vidx, 1);
    float32x2_t vl_lo = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_lo]);
    float32x2_t vl_hi = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_hi]);
    vl_lo = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_lo >> 32)], vl_lo, 1);
    vl_hi = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_hi >> 32)], vl_hi, 1);
    const float32x4_t vl = vcombine_f32(vl_lo, vl_hi);
    const float32x4_t vs = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl), ve));

    vn = vsubq_f32(vn, vmagic_bias);

    $if FMA:
      float32x4_t vt = ${VMULADDQ_F32}(vx, vn, vminus_ln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vx, vn, vminus_ln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vminus_ln2_lo);

    float32x4_t vp = vmulq_f32(vt, vc2);
    vp = ${VMULADDQ_F32}(vt, vt, vp);

    float32x4_t vf = ${VMULADDQ_F32}(vs, vs, vp);

    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcltq_f32(vx, vdenorm_cutoff)));

    vst1q_f32(output, vf); output += 4;

    vacc = vaddq_f32(vacc, vf);
  }
  // Reduce the vector accumulator: to a scalar on AArch64, to a 2-lane partial sum otherwise.
#if XNN_ARCH_ARM64
  float vacc_lo = vaddvq_f32(vacc);
#else
  float32x2_t vacc_lo = vadd_f32(vget_high_f32(vacc), vget_low_f32(vacc));
#endif
  if (elements != 0) {
    assert(elements >= 1 * sizeof(float));
    assert(elements <= 3 * sizeof(float));
    // 1 to 3 elements remain; the kernel is declared XNN_OOB_READS, so a full-vector load past the end is allowed.
    const float32x4_t vi = vld1q_f32(input); input += 4;

    const float32x4_t vx = vsubq_f32(vi, vi_max);

    float32x4_t vn = ${VMULADDQ_F32}(vmagic_bias, vx, vlog2e);

    const int32x4_t ve = vshlq_n_s32(vbicq_s32(vreinterpretq_s32_f32(vn), vmovq_n_s32(INT32_C(0x3F))), 17);

    const uint64x2_t vidx = vreinterpretq_u64_s32(vandq_s32(vreinterpretq_s32_f32(vn), vindex_mask));
    const uint64_t vidx_lo = vgetq_lane_u64(vidx, 0);
    const uint64_t vidx_hi = vgetq_lane_u64(vidx, 1);
    float32x2_t vl_lo = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_lo]);
    float32x2_t vl_hi = vld1_dup_f32(&xnn_table_exp2_k_over_64[(uint32_t) vidx_hi]);
    vl_lo = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_lo >> 32)], vl_lo, 1);
    vl_hi = vld1_lane_f32(&xnn_table_exp2_k_over_64[(uint32_t) (vidx_hi >> 32)], vl_hi, 1);
    const float32x4_t vl = vcombine_f32(vl_lo, vl_hi);
    const float32x4_t vs = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(vl), ve));

    vn = vsubq_f32(vn, vmagic_bias);

    $if FMA:
      float32x4_t vt = ${VMULADDQ_F32}(vx, vn, vminus_ln2);
    $else:
      float32x4_t vt = ${VMULADDQ_F32}(vx, vn, vminus_ln2_hi);
      vt = ${VMULADDQ_F32}(vt, vn, vminus_ln2_lo);

    float32x4_t vp = vmulq_f32(vt, vc2);
    vp = ${VMULADDQ_F32}(vt, vt, vp);

    float32x4_t vf = ${VMULADDQ_F32}(vs, vs, vp);

    vf = vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(vf), vcltq_f32(vx, vdenorm_cutoff)));

    // Store and accumulate only the valid lanes.
    float32x2_t vf_lo = vget_low_f32(vf);
    if (elements & (2 * sizeof(float))) {
      vst1_f32(output, vf_lo); output += 2;

      #if XNN_ARCH_ARM64
        vacc_lo += vaddv_f32(vf_lo);
      #else
        vacc_lo = vadd_f32(vacc_lo, vf_lo);
      #endif

      vf_lo = vget_high_f32(vf);
    }
    if (elements & (1 * sizeof(float))) {
      vst1_lane_f32(output, vf_lo, 0);

      #if XNN_ARCH_ARM64
        vacc_lo += vget_lane_f32(vf_lo, 0);
      #else
        vacc_lo = vadd_f32(vacc_lo, vreinterpret_f32_u64(vshl_n_u64(vreinterpret_u64_f32(vf_lo), 32)));
      #endif
    }
  }
#if XNN_ARCH_ARM64
  *sum = vacc_lo;
#else
  vst1_lane_f32(sum, vpadd_f32(vacc_lo, vacc_lo), 0);
#endif
}