// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ELEMENTS_TILE % 16 == 0
$assert ELEMENTS_TILE >= 16
$SIMD_TILE = ELEMENTS_TILE // 16
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/raddstoreexpminusmax.h>


void xnn_f32_raddstoreexpminusmax_ukernel__avx512f_rr1_p5_scalef_x${ELEMENTS_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t elements,
    const float* input,
    const float* max,
    float* output,
    float* sum,
    const union xnn_f32_expminus_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(elements % sizeof(float) == 0);

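  // Broadcast the maximum input value and the exp(x - max) approximation constants:
  // log2(e), -ln(2), and the degree-5 polynomial coefficients c5..c0.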
  const __m512 vi_max = _mm512_set1_ps(*max);
  const __m512 vlog2e = _mm512_set1_ps(params->avx512_rr1_p5.log2e);
  const __m512 vminus_ln2 = _mm512_set1_ps(params->avx512_rr1_p5.minus_ln2);
  const __m512 vc5 = _mm512_set1_ps(params->avx512_rr1_p5.c5);
  const __m512 vc4 = _mm512_set1_ps(params->avx512_rr1_p5.c4);
  const __m512 vc3 = _mm512_set1_ps(params->avx512_rr1_p5.c3);
  const __m512 vc2 = _mm512_set1_ps(params->avx512_rr1_p5.c2);
  const __m512 vc1 = _mm512_set1_ps(params->avx512_rr1_p5.c1);
  const __m512 vc0 = _mm512_set1_ps(params->avx512_rr1_p5.c0);

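  // Zero-initialize the partial-sum accumulator(s).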
  $for K in range(ACCUMULATORS):
    __m512 vacc${K} = _mm512_setzero_ps();
  for (; elements >= ${ELEMENTS_TILE} * sizeof(float); elements -= ${ELEMENTS_TILE} * sizeof(float)) {
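    // Load ${ELEMENTS_TILE} (${SIMD_TILE}x16) inputs at a time.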
    const __m512 vi0 = _mm512_loadu_ps(input);
    $for N in range(1, SIMD_TILE):
      const __m512 vi${N} = _mm512_loadu_ps(input + ${N * 16});
    input += ${ELEMENTS_TILE};

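    // Subtract the maximum input value: x := i - i_max.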
    $for N in range(SIMD_TILE):
      const __m512 vx${N} = _mm512_sub_ps(vi${N}, vi_max);

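    // Compute the reduced-argument exponent: n := round(x * log2(e)).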
    $for N in range(SIMD_TILE):
      const __m512 vn${N} = _mm512_roundscale_ps(_mm512_mul_ps(vx${N}, vlog2e), 0);

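    // Compute the reduced argument: t := x - n * ln(2).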
    $for N in range(SIMD_TILE):
      const __m512 vt${N} = _mm512_fmadd_ps(vn${N}, vminus_ln2, vx${N});

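    // Evaluate the degree-5 polynomial approximation of exp(t) with Horner's scheme:
    // p := c0 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))).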
    $for N in range(SIMD_TILE):
      __m512 vp${N} = _mm512_fmadd_ps(vc5, vt${N}, vc4);

    $for N in range(SIMD_TILE):
      vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc3);

    $for N in range(SIMD_TILE):
      vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc2);

    $for N in range(SIMD_TILE):
      vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc1);

    $for N in range(SIMD_TILE):
      vp${N} = _mm512_fmadd_ps(vp${N}, vt${N}, vc0);

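    // Reconstruct the exponential: f := p * 2**n, using the AVX512 SCALEF instruction.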
    $for N in range(SIMD_TILE):
      const __m512 vf${N} = _mm512_scalef_ps(vp${N}, vn${N});

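    // Store the computed exponentials to the output buffer.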
    _mm512_storeu_ps(output, vf0);
    $for N in range(1, SIMD_TILE):
      _mm512_storeu_ps(output + ${N * 16}, vf${N});
    output += ${ELEMENTS_TILE};

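    // Accumulate the computed exponentials into the partial sum(s).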
    $for N in range(SIMD_TILE):
      vacc${N % ACCUMULATORS} = _mm512_add_ps(vacc${N % ACCUMULATORS}, vf${N});
  }
  $if ACCUMULATORS > 1:
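    // Add up all accumulators to vacc0.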
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vacc${A} = _mm512_add_ps(vacc${A}, vacc${A + ACC_SLICE});
      $ACC_SLICE *= 2

  __m512 vacc = vacc0;
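  // Process remaining full blocks of 16 elements.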
  for (; elements >= 16 * sizeof(float); elements -= 16 * sizeof(float)) {
    const __m512 vi = _mm512_loadu_ps(input);
    input += 16;

    const __m512 vx = _mm512_sub_ps(vi, vi_max);

    const __m512 vn = _mm512_roundscale_ps(_mm512_mul_ps(vx, vlog2e), 0);

    const __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2, vx);

    __m512 vp = _mm512_fmadd_ps(vc5, vt, vc4);
    vp = _mm512_fmadd_ps(vp, vt, vc3);
    vp = _mm512_fmadd_ps(vp, vt, vc2);
    vp = _mm512_fmadd_ps(vp, vt, vc1);
    vp = _mm512_fmadd_ps(vp, vt, vc0);

    const __m512 vf = _mm512_scalef_ps(vp, vn);

    _mm512_storeu_ps(output, vf);
    output += 16;

    vacc = _mm512_add_ps(vacc, vf);
  }
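  // Handle the final partial block of 1..15 elements with masked loads and stores.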
  if (elements != 0) {
    // Prepare mask for valid 32-bit elements (depends on elements).
    elements >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << elements) - UINT32_C(1)));

    const __m512 vi = _mm512_maskz_loadu_ps(vmask, input);

    const __m512 vx = _mm512_sub_ps(vi, vi_max);

    const __m512 vn = _mm512_roundscale_ps(_mm512_mul_ps(vx, vlog2e), 0);

    const __m512 vt = _mm512_fmadd_ps(vn, vminus_ln2, vx);

    __m512 vp = _mm512_fmadd_ps(vc5, vt, vc4);
    vp = _mm512_fmadd_ps(vp, vt, vc3);
    vp = _mm512_fmadd_ps(vp, vt, vc2);
    vp = _mm512_fmadd_ps(vp, vt, vc1);
    vp = _mm512_fmadd_ps(vp, vt, vc0);

    const __m512 vf = _mm512_scalef_ps(vp, vn);

    _mm512_mask_storeu_ps(output, vmask, vf);

    vacc = _mm512_mask_add_ps(vacc, vmask, vacc, vf);
  }
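  // Reduce the vector accumulator to a scalar and store the sum of the exponentials.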
  *sum = _mm512_reduce_add_ps(vacc);
}