// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
#include <assert.h>

#include <${SSE_HEADER}>

#include <xnnpack/vunary.h>
#include <xnnpack/common.h>


$ISA = {2: "sse2", 4: "sse41"}[SSE]
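// ELU nonlinearity:
//
//   f(x) = x < 0 ? alpha * (exp(x * prescale) - 1) : beta * x
//
// The exponential is evaluated as exp(z) = 2**n * exp(t) ("rr2"): n is obtained by rounding
// z * log2(e) to an integer via the magic-bias trick, t = z - n * log(2) is computed with a
// two-term (hi/lo) reduction, and exp(t) - 1 is approximated on the reduced range by a
// degree-6 polynomial ("p6") whose coefficients c2..c6 come from the kernel parameters.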
void xnn_f32_velu_ukernel__${ISA}_rr2_p6_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m128 vprescale = _mm_load_ps(params->sse2_rr2_p6.prescale);
  const __m128 valpha = _mm_load_ps(params->sse2_rr2_p6.alpha);
  const __m128 vbeta = _mm_load_ps(params->sse2_rr2_p6.beta);
  const __m128 vsat_cutoff = _mm_load_ps(params->sse2_rr2_p6.sat_cutoff);
  const __m128 vmagic_bias = _mm_load_ps(params->sse2_rr2_p6.magic_bias);
  const __m128 vlog2e = _mm_load_ps(params->sse2_rr2_p6.log2e);
  const __m128 vminus_ln2_hi = _mm_load_ps(params->sse2_rr2_p6.minus_ln2_hi);
  const __m128 vminus_ln2_lo = _mm_load_ps(params->sse2_rr2_p6.minus_ln2_lo);
  const __m128 vc6 = _mm_load_ps(params->sse2_rr2_p6.c6);
  const __m128 vc5 = _mm_load_ps(params->sse2_rr2_p6.c5);
  const __m128 vc4 = _mm_load_ps(params->sse2_rr2_p6.c4);
  const __m128 vc3 = _mm_load_ps(params->sse2_rr2_p6.c3);
  const __m128 vc2 = _mm_load_ps(params->sse2_rr2_p6.c2);
  const __m128 vone = _mm_load_ps(params->sse2_rr2_p6.one);

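  // Unrolled main loop: process ${BATCH_TILE} elements per iteration.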
  $if BATCH_TILE > 4:
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      __m128 vx${ABC[0:4]} = _mm_loadu_ps(x);
      $for N in range(4, BATCH_TILE, 4):
        __m128 vx${ABC[N:N+4]} = _mm_loadu_ps(x + ${N});
      x += ${BATCH_TILE};

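      // Multiply the inputs by the prescale factor and saturate very negative results at
      // sat_cutoff, keeping the subsequent exponent computation in range.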
      $for N in range(0, BATCH_TILE, 4):
        const __m128 vz${ABC[N:N+4]} = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx${ABC[N:N+4]}, vprescale));

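      // Compute n := round(z / log(2)) by adding a large magic bias to z * log2(e); the
      // rounded integer ends up in the low bits of the float representation.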
      $for N in range(0, BATCH_TILE, 4):
        __m128 vn${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vz${ABC[N:N+4]}, vlog2e), vmagic_bias);

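      // Build s := 2**n by shifting the integer bits of n into the floating-point exponent field.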
      $for N in range(0, BATCH_TILE, 4):
        __m128 vs${ABC[N:N+4]} = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn${ABC[N:N+4]}), 23));

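      // Subtract the magic bias to recover n as a float.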
      $for N in range(0, BATCH_TILE, 4):
        vn${ABC[N:N+4]} = _mm_sub_ps(vn${ABC[N:N+4]}, vmagic_bias);

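      // Reduce the argument: t := z - n * log(2), using a two-term (hi + lo) representation
      // of log(2) for extra accuracy.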
      $for N in range(0, BATCH_TILE, 4):
        __m128 vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_hi), vz${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vn${ABC[N:N+4]}, vminus_ln2_lo), vt${ABC[N:N+4]});

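      // Evaluate the polynomial p := t * (c2 + t * (c3 + t * (c4 + t * (c5 + t * c6)))),
      // so that exp(t) - 1 ~= t + t * p on the reduced range.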
      $for N in range(0, BATCH_TILE, 4):
        __m128 vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vc6, vt${ABC[N:N+4]}), vc5);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc4);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc3);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vc2);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]});

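      // Reconstruct e := alpha * (exp(z) - 1) = alpha * (s * (1 + t + t * p) - 1)
      //                = alpha * ((p * (t * s) + (t * s)) + (s - 1)).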
      $for N in range(0, BATCH_TILE, 4):
        vt${ABC[N:N+4]} = _mm_mul_ps(vt${ABC[N:N+4]}, vs${ABC[N:N+4]});
        vs${ABC[N:N+4]} = _mm_sub_ps(vs${ABC[N:N+4]}, vone);

      $for N in range(0, BATCH_TILE, 4):
        vp${ABC[N:N+4]} = _mm_add_ps(_mm_mul_ps(vp${ABC[N:N+4]}, vt${ABC[N:N+4]}), vt${ABC[N:N+4]});

      $for N in range(0, BATCH_TILE, 4):
        const __m128 ve${ABC[N:N+4]} = _mm_mul_ps(_mm_add_ps(vp${ABC[N:N+4]}, vs${ABC[N:N+4]}), valpha);

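      // Select the output: e for negative inputs (sign bit of x set), beta * x otherwise.
      // SSE4.1 uses BLENDVPS on the sign bit directly; SSE2 emulates the blend with an
      // integer comparison mask and AND/ANDNOT/OR.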
      $for N in range(0, BATCH_TILE, 4):
        $if SSE < 4:
          const __m128 vm${ABC[N:N+4]} = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx${ABC[N:N+4]})));
        vx${ABC[N:N+4]} = _mm_mul_ps(vx${ABC[N:N+4]}, vbeta);

      $for N in range(0, BATCH_TILE, 4):
        $if SSE >= 4:
          const __m128 vy${ABC[N:N+4]} = _mm_blendv_ps(vx${ABC[N:N+4]}, ve${ABC[N:N+4]}, vx${ABC[N:N+4]});
        $else:
          const __m128 vy${ABC[N:N+4]} = _mm_or_ps(_mm_and_ps(ve${ABC[N:N+4]}, vm${ABC[N:N+4]}), _mm_andnot_ps(vm${ABC[N:N+4]}, vx${ABC[N:N+4]}));

      _mm_storeu_ps(y, vy${ABC[0:4]});
      $for N in range(4, BATCH_TILE, 4):
        _mm_storeu_ps(y + ${N}, vy${ABC[N:N+4]});
      y += ${BATCH_TILE};
    }
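  // Process the remaining full groups of 4 elements; same computation as the main loop.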
  for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) {
    __m128 vx = _mm_loadu_ps(x);
    x += 4;

    const __m128 vz = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx, vprescale));

    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);
    __m128 vs = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn), 23));
    vn = _mm_sub_ps(vn, vmagic_bias);

    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

    __m128 vp = _mm_add_ps(_mm_mul_ps(vc6, vt), vc5);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc4);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc3);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2);
    vp = _mm_mul_ps(vp, vt);

    vt = _mm_mul_ps(vt, vs);
    vs = _mm_sub_ps(vs, vone);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vt);
    const __m128 ve = _mm_mul_ps(_mm_add_ps(vp, vs), valpha);

    $if SSE < 4:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
    vx = _mm_mul_ps(vx, vbeta);
    $if SSE >= 4:
      const __m128 vy = _mm_blendv_ps(vx, ve, vx);
    $else:
      const __m128 vy = _mm_or_ps(_mm_and_ps(ve, vm), _mm_andnot_ps(vm, vx));

    _mm_storeu_ps(y, vy);
    y += 4;
  }
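  // Handle the final 1-3 elements: load a full vector (reading past the end is permitted,
  // see XNN_OOB_READS), compute as usual, and store only the valid lanes.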
  if XNN_UNLIKELY(n != 0) {
    __m128 vx = _mm_loadu_ps(x);

    const __m128 vz = _mm_max_ps(vsat_cutoff, _mm_mul_ps(vx, vprescale));

    __m128 vn = _mm_add_ps(_mm_mul_ps(vz, vlog2e), vmagic_bias);
    __m128 vs = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn), 23));
    vn = _mm_sub_ps(vn, vmagic_bias);

    __m128 vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_hi), vz);
    vt = _mm_add_ps(_mm_mul_ps(vn, vminus_ln2_lo), vt);

    __m128 vp = _mm_add_ps(_mm_mul_ps(vc6, vt), vc5);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc4);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc3);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vc2);
    vp = _mm_mul_ps(vp, vt);

    vt = _mm_mul_ps(vt, vs);
    vs = _mm_sub_ps(vs, vone);
    vp = _mm_add_ps(_mm_mul_ps(vp, vt), vt);
    const __m128 ve = _mm_mul_ps(_mm_add_ps(vp, vs), valpha);

    $if SSE < 4:
      const __m128 vm = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vx)));
    vx = _mm_mul_ps(vx, vbeta);
    $if SSE >= 4:
      __m128 vy = _mm_blendv_ps(vx, ve, vx);
    $else:
      __m128 vy = _mm_or_ps(_mm_and_ps(ve, vm), _mm_andnot_ps(vm, vx));

    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vy);
      vy = _mm_movehl_ps(vy, vy);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vy);
    }
  }
}