// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/raddstoreexpminusmax.h>


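// Computes exp(x[i] - max) for a row of half-precision inputs, writes the results
// back as half precision, and accumulates their sum (typically used as the core of
// a numerically stable softmax). The argument is reduced with a single ln(2)
// constant ("rr1") and the exponential is approximated with a degree-2 polynomial ("p2").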
void xnn_f16_raddstoreexpminusmax_ukernel__avx2_rr1_p2_x${BATCH_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t batch,
    const void* input,
    const void* max,
    void* output,
    void* sum,
    const union xnn_f16_expminus_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch % sizeof(uint16_t) == 0);

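  // Broadcast the half-precision maximum and load the approximation constants;
  // all constants are prepared as 8-lane single-precision vectors in the kernel parameters.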
  const __m256 vi_max = _mm256_cvtph_ps(_mm_set1_epi16((short) *((const uint16_t*) max)));
  const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p2.log2e);
  const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p2.magic_bias);
  const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p2.minus_ln2);
  const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p2.c2);
  const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p2.c1);
  const __m256 vdenorm_cutoff = _mm256_load_ps(params->avx2_rr1_p2.denorm_cutoff);

  const uint16_t* i = (const uint16_t*) input;
  uint16_t* o = (uint16_t*) output;
  $for K in range(ACCUMULATORS):
    __m256 vacc${K} = _mm256_setzero_ps();
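  // Main loop: process ${BATCH_TILE} elements per iteration, accumulating the sum
  // in ${ACCUMULATORS} independent accumulator(s) to relax the addition dependency chain.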
  for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) {
    const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
    $for N in range(1, SIMD_TILE):
      const __m256 vi${ABC[N]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + ${N * 8})));
    i += ${BATCH_TILE};

    $for N in range(SIMD_TILE):
      const __m256 vx${ABC[N]} = _mm256_sub_ps(vi${ABC[N]}, vi_max);

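    // n := round(x / ln(2)), computed by scaling with log2(e) and adding a large
    // "magic bias" so that the rounded integer lands in the low mantissa bits.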
    $for N in range(SIMD_TILE):
      __m256 vn${ABC[N]} = _mm256_fmadd_ps(vx${ABC[N]}, vlog2e, vmagic_bias);

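    // s := 2**n, built by shifting those mantissa bits into the exponent field.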
    $for N in range(SIMD_TILE):
      const __m256 vs${ABC[N]} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${ABC[N]}), 23));

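    // Subtract the magic bias to recover n as a regular float.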
    $for N in range(SIMD_TILE):
      vn${ABC[N]} = _mm256_sub_ps(vn${ABC[N]}, vmagic_bias);

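    // t := x - n * ln(2), the reduced argument (roughly within +/- ln(2)/2).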
    $for N in range(SIMD_TILE):
      __m256 vt${ABC[N]} = _mm256_fmadd_ps(vn${ABC[N]}, vminus_ln2, vx${ABC[N]});

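    // Degree-2 polynomial: exp(t) ~= 1 + t * (c1 + t * c2), with p := c2 * t + c1.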
    $for N in range(SIMD_TILE):
      const __m256 vp${ABC[N]} = _mm256_fmadd_ps(vc2, vt${ABC[N]}, vc1);

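    // Reconstruct exp(x) = s * exp(t) ~= s + (t * s) * p.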
    $for N in range(SIMD_TILE):
      vt${ABC[N]} = _mm256_mul_ps(vt${ABC[N]}, vs${ABC[N]});

    $for N in range(SIMD_TILE):
      __m256 vf${ABC[N]} = _mm256_fmadd_ps(vt${ABC[N]}, vp${ABC[N]}, vs${ABC[N]});

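    // Zero out results for inputs below the denormal cutoff, where the exponential would underflow.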
    $for N in range(SIMD_TILE):
      vf${ABC[N]} = _mm256_andnot_ps(_mm256_cmp_ps(vx${ABC[N]}, vdenorm_cutoff, _CMP_LT_OS), vf${ABC[N]});

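    // Convert back to half precision, store, and add every result to the accumulators.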
    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf0, _MM_FROUND_NO_EXC));
    $for N in range(1, SIMD_TILE):
      _mm_storeu_si128((__m128i*) (o + ${N * 8}), _mm256_cvtps_ph(vf${ABC[N]}, _MM_FROUND_NO_EXC));
    o += ${BATCH_TILE};

    $for N in range(SIMD_TILE):
      vacc${N % ACCUMULATORS} = _mm256_add_ps(vacc${N % ACCUMULATORS}, vf${ABC[N]});
  }
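  // Fold the partial accumulators into a single running sum.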
  $if ACCUMULATORS > 1:
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vacc${A} = _mm256_add_ps(vacc${A}, vacc${A + ACC_SLICE});
      $ACC_SLICE *= 2

  __m256 vacc = vacc0;
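  // Tail loop: process the remaining full vectors of 8 elements with the same computation.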
  for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) {
    const __m256 vi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
    i += 8;

    const __m256 vx = _mm256_sub_ps(vi, vi_max);

    __m256 vn = _mm256_fmadd_ps(vx, vlog2e, vmagic_bias);

    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));

    vn = _mm256_sub_ps(vn, vmagic_bias);

    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vx);

    const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
    vt = _mm256_mul_ps(vt, vs);
    __m256 vf = _mm256_fmadd_ps(vt, vp, vs);
    vf = _mm256_andnot_ps(_mm256_cmp_ps(vx, vdenorm_cutoff, _CMP_LT_OS), vf);

    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC));
    o += 8;

    vacc = _mm256_add_ps(vacc, vf);
  }
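  // Reduce the 256-bit accumulator to 128 bits by adding its two halves.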
  __m128 vacc_lo = _mm_add_ps(_mm256_castps256_ps128(vacc), _mm256_extractf128_ps(vacc, 1));
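  // Handle the last 1-7 elements. A full vector of 8 is loaded (out-of-bounds reads
  // are permitted by XNN_OOB_READS) and computed, but only the valid lanes are
  // stored and added to the sum.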
  if (batch != 0) {
    assert(batch >= 1 * sizeof(uint16_t));
    assert(batch <= 7 * sizeof(uint16_t));

    const __m256 vi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));

    const __m256 vx = _mm256_sub_ps(vi, vi_max);

    __m256 vn = _mm256_fmadd_ps(vx, vlog2e, vmagic_bias);

    const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));

    vn = _mm256_sub_ps(vn, vmagic_bias);

    __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vx);

    const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
    vt = _mm256_mul_ps(vt, vs);
    __m256 vf = _mm256_fmadd_ps(vt, vp, vs);
    vf = _mm256_andnot_ps(_mm256_cmp_ps(vx, vdenorm_cutoff, _CMP_LT_OS), vf);

    __m128i vh = _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC);
    __m128 vf_lo = _mm256_castps256_ps128(vf);
    if (batch & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) o, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      vacc_lo = _mm_add_ps(vacc_lo, vf_lo);
      vf_lo = _mm256_extractf128_ps(vf, 1);
      o += 4;
    }
    if (batch & (2 * sizeof(uint16_t))) {
      _mm_storeu_si32(o, vh);
      vh = _mm_srli_epi64(vh, 32);
      vacc_lo = _mm_blend_ps(_mm_add_ps(vacc_lo, vf_lo), vacc_lo, 0xC);
      vf_lo = _mm_movehl_ps(vf_lo, vf_lo);
      o += 2;
    }
    if (batch & (1 * sizeof(uint16_t))) {
      *o = (uint16_t) _mm_extract_epi16(vh, 0);
      vacc_lo = _mm_add_ss(vacc_lo, vf_lo);
    }
  }
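  // Horizontally add the four partial sums and store the result as half precision.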
  vacc_lo = _mm_add_ps(vacc_lo, _mm_movehl_ps(vacc_lo, vacc_lo));
  vacc_lo = _mm_add_ss(vacc_lo, _mm_movehdup_ps(vacc_lo));
  *((uint16_t*) sum) = (uint16_t) _mm_extract_epi16(_mm_cvtps_ph(vacc_lo, _MM_FROUND_NO_EXC), 0);
  _mm256_zeroupper();
}