xref: /aosp_15_r20/external/XNNPACK/src/f32-velu/scalar-rr2-p6.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2020 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker//
3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker$assert BATCH_TILE >= 1
7*4bdc9457SAndroid Build Coastguard Worker#include <assert.h>
8*4bdc9457SAndroid Build Coastguard Worker#include <math.h>
9*4bdc9457SAndroid Build Coastguard Worker
10*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/common.h>
11*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/math.h>
12*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/vunary.h>
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Worker
15*4bdc9457SAndroid Build Coastguard Workervoid xnn_f32_velu_ukernel__${"wasm" if WASM else "scalar"}_rr2_p6_x${BATCH_TILE}(
16*4bdc9457SAndroid Build Coastguard Worker    size_t n,
17*4bdc9457SAndroid Build Coastguard Worker    const float* x,
18*4bdc9457SAndroid Build Coastguard Worker    float* y,
19*4bdc9457SAndroid Build Coastguard Worker    const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
20*4bdc9457SAndroid Build Coastguard Worker{
21*4bdc9457SAndroid Build Coastguard Worker  assert(n % sizeof(float) == 0);
22*4bdc9457SAndroid Build Coastguard Worker
23*4bdc9457SAndroid Build Coastguard Worker  const float vprescale = params->scalar_rr2_p6.prescale;
24*4bdc9457SAndroid Build Coastguard Worker  const float valpha = params->scalar_rr2_p6.alpha;
25*4bdc9457SAndroid Build Coastguard Worker  const float vbeta = params->scalar_rr2_p6.beta;
26*4bdc9457SAndroid Build Coastguard Worker  const float vmagic_bias = params->scalar_rr2_p6.magic_bias;
27*4bdc9457SAndroid Build Coastguard Worker  const float vlog2e = params->scalar_rr2_p6.log2e;
28*4bdc9457SAndroid Build Coastguard Worker  const float vsat_cutoff = params->scalar_rr2_p6.sat_cutoff;
29*4bdc9457SAndroid Build Coastguard Worker  const float vminus_ln2_hi = params->scalar_rr2_p6.minus_ln2_hi;
30*4bdc9457SAndroid Build Coastguard Worker  const float vminus_ln2_lo = params->scalar_rr2_p6.minus_ln2_lo;
31*4bdc9457SAndroid Build Coastguard Worker  const float vc6 = params->scalar_rr2_p6.c6;
32*4bdc9457SAndroid Build Coastguard Worker  const float vc5 = params->scalar_rr2_p6.c5;
33*4bdc9457SAndroid Build Coastguard Worker  const float vc4 = params->scalar_rr2_p6.c4;
34*4bdc9457SAndroid Build Coastguard Worker  const float vc3 = params->scalar_rr2_p6.c3;
35*4bdc9457SAndroid Build Coastguard Worker  const float vc2 = params->scalar_rr2_p6.c2;
36*4bdc9457SAndroid Build Coastguard Worker  const float vone = params->scalar_rr2_p6.one;
37*4bdc9457SAndroid Build Coastguard Worker
38*4bdc9457SAndroid Build Coastguard Worker  $if BATCH_TILE > 1:
39*4bdc9457SAndroid Build Coastguard Worker    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
40*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
41*4bdc9457SAndroid Build Coastguard Worker        float vx${N} = x[${N}];
42*4bdc9457SAndroid Build Coastguard Worker      x += ${BATCH_TILE};
43*4bdc9457SAndroid Build Coastguard Worker
44*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
45*4bdc9457SAndroid Build Coastguard Worker        $if WASM:
46*4bdc9457SAndroid Build Coastguard Worker          const float vz${N} = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx${N} * vprescale, vsat_cutoff), 0.0f);
47*4bdc9457SAndroid Build Coastguard Worker        $else:
48*4bdc9457SAndroid Build Coastguard Worker          const float vz${N} = vx${N} * vprescale;
49*4bdc9457SAndroid Build Coastguard Worker
50*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
51*4bdc9457SAndroid Build Coastguard Worker        float vn${N} = vz${N} * vlog2e + vmagic_bias;
52*4bdc9457SAndroid Build Coastguard Worker
53*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
54*4bdc9457SAndroid Build Coastguard Worker        float vs${N} = uint32_as_float(float_as_uint32(vn${N}) << 23);
55*4bdc9457SAndroid Build Coastguard Worker        vn${N} -= vmagic_bias;
56*4bdc9457SAndroid Build Coastguard Worker
57*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
58*4bdc9457SAndroid Build Coastguard Worker        float vt${N} = vn${N} * vminus_ln2_hi + vz${N};
59*4bdc9457SAndroid Build Coastguard Worker
60*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
61*4bdc9457SAndroid Build Coastguard Worker        vt${N} = vn${N} * vminus_ln2_lo + vt${N};
62*4bdc9457SAndroid Build Coastguard Worker
63*4bdc9457SAndroid Build Coastguard Worker      $if not WASM:
64*4bdc9457SAndroid Build Coastguard Worker        $for N in range(BATCH_TILE):
65*4bdc9457SAndroid Build Coastguard Worker          if XNN_UNPREDICTABLE(vz${N} <= vsat_cutoff) {
66*4bdc9457SAndroid Build Coastguard Worker            vs${N} = 0.0f;
67*4bdc9457SAndroid Build Coastguard Worker            vt${N} = 0.0f;
68*4bdc9457SAndroid Build Coastguard Worker          }
69*4bdc9457SAndroid Build Coastguard Worker
70*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
71*4bdc9457SAndroid Build Coastguard Worker        float vp${N} = vc6 * vt${N} + vc5;
72*4bdc9457SAndroid Build Coastguard Worker
73*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
74*4bdc9457SAndroid Build Coastguard Worker        vp${N} = vp${N} * vt${N} + vc4;
75*4bdc9457SAndroid Build Coastguard Worker
76*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
77*4bdc9457SAndroid Build Coastguard Worker        vp${N} = vp${N} * vt${N} + vc3;
78*4bdc9457SAndroid Build Coastguard Worker
79*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
80*4bdc9457SAndroid Build Coastguard Worker        vp${N} = vp${N} * vt${N} + vc2;
81*4bdc9457SAndroid Build Coastguard Worker
82*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
83*4bdc9457SAndroid Build Coastguard Worker        vp${N} *= vt${N};
84*4bdc9457SAndroid Build Coastguard Worker
85*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
86*4bdc9457SAndroid Build Coastguard Worker        vt${N} *= vs${N};
87*4bdc9457SAndroid Build Coastguard Worker        vs${N} -= vone;
88*4bdc9457SAndroid Build Coastguard Worker
89*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
90*4bdc9457SAndroid Build Coastguard Worker        vp${N} = vp${N} * vt${N} + vt${N};
91*4bdc9457SAndroid Build Coastguard Worker
92*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
93*4bdc9457SAndroid Build Coastguard Worker        const float ve${N} = (vp${N} + vs${N}) * valpha;
94*4bdc9457SAndroid Build Coastguard Worker        $if WASM:
95*4bdc9457SAndroid Build Coastguard Worker          float vy${N} = __builtin_wasm_max_f32(vx${N} * vbeta, 0.0f);
96*4bdc9457SAndroid Build Coastguard Worker        $else:
97*4bdc9457SAndroid Build Coastguard Worker          float vy${N} = vx${N} * vbeta;
98*4bdc9457SAndroid Build Coastguard Worker
99*4bdc9457SAndroid Build Coastguard Worker      $if WASM:
100*4bdc9457SAndroid Build Coastguard Worker        $for N in range(BATCH_TILE):
101*4bdc9457SAndroid Build Coastguard Worker          vy${N} += __builtin_wasm_min_f32(ve${N}, 0.0f);
102*4bdc9457SAndroid Build Coastguard Worker      $else:
103*4bdc9457SAndroid Build Coastguard Worker        $for N in range(BATCH_TILE):
104*4bdc9457SAndroid Build Coastguard Worker          if XNN_UNPREDICTABLE(vx${N} < 0.0f) {
105*4bdc9457SAndroid Build Coastguard Worker            vy${N} = ve${N};
106*4bdc9457SAndroid Build Coastguard Worker          }
107*4bdc9457SAndroid Build Coastguard Worker
108*4bdc9457SAndroid Build Coastguard Worker      $for N in range(BATCH_TILE):
109*4bdc9457SAndroid Build Coastguard Worker        y[${N}] = vy${N};
110*4bdc9457SAndroid Build Coastguard Worker      y += ${BATCH_TILE};
111*4bdc9457SAndroid Build Coastguard Worker    }
112*4bdc9457SAndroid Build Coastguard Worker  $if BATCH_TILE == 1:
113*4bdc9457SAndroid Build Coastguard Worker    do {
114*4bdc9457SAndroid Build Coastguard Worker      float vx = *x++;
115*4bdc9457SAndroid Build Coastguard Worker
116*4bdc9457SAndroid Build Coastguard Worker      $if WASM:
117*4bdc9457SAndroid Build Coastguard Worker        const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
118*4bdc9457SAndroid Build Coastguard Worker      $else:
119*4bdc9457SAndroid Build Coastguard Worker        const float vz = vx * vprescale;
120*4bdc9457SAndroid Build Coastguard Worker
121*4bdc9457SAndroid Build Coastguard Worker      float vn = vz * vlog2e + vmagic_bias;
122*4bdc9457SAndroid Build Coastguard Worker      float vs = uint32_as_float(float_as_uint32(vn) << 23);
123*4bdc9457SAndroid Build Coastguard Worker      vn -= vmagic_bias;
124*4bdc9457SAndroid Build Coastguard Worker
125*4bdc9457SAndroid Build Coastguard Worker      float vt = vn * vminus_ln2_hi + vz;
126*4bdc9457SAndroid Build Coastguard Worker      vt = vn * vminus_ln2_lo + vt;
127*4bdc9457SAndroid Build Coastguard Worker
128*4bdc9457SAndroid Build Coastguard Worker      $if not WASM:
129*4bdc9457SAndroid Build Coastguard Worker        if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
130*4bdc9457SAndroid Build Coastguard Worker          vs = 0.0f;
131*4bdc9457SAndroid Build Coastguard Worker          vt = 0.0f;
132*4bdc9457SAndroid Build Coastguard Worker        }
133*4bdc9457SAndroid Build Coastguard Worker
134*4bdc9457SAndroid Build Coastguard Worker      float vp = vc6 * vt + vc5;
135*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vc4;
136*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vc3;
137*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vc2;
138*4bdc9457SAndroid Build Coastguard Worker      vp *= vt;
139*4bdc9457SAndroid Build Coastguard Worker
140*4bdc9457SAndroid Build Coastguard Worker      vt *= vs;
141*4bdc9457SAndroid Build Coastguard Worker      vs -= vone;
142*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vt;
143*4bdc9457SAndroid Build Coastguard Worker      const float ve = (vp + vs) * valpha;
144*4bdc9457SAndroid Build Coastguard Worker
145*4bdc9457SAndroid Build Coastguard Worker      $if WASM:
146*4bdc9457SAndroid Build Coastguard Worker        float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
147*4bdc9457SAndroid Build Coastguard Worker        vy += __builtin_wasm_min_f32(ve, 0.0f);
148*4bdc9457SAndroid Build Coastguard Worker      $else:
149*4bdc9457SAndroid Build Coastguard Worker        float vy = vx * vbeta;
150*4bdc9457SAndroid Build Coastguard Worker        if XNN_UNPREDICTABLE(vx < 0.0f) {
151*4bdc9457SAndroid Build Coastguard Worker          vy = ve;
152*4bdc9457SAndroid Build Coastguard Worker        }
153*4bdc9457SAndroid Build Coastguard Worker
154*4bdc9457SAndroid Build Coastguard Worker      *y++ = vy;
155*4bdc9457SAndroid Build Coastguard Worker
156*4bdc9457SAndroid Build Coastguard Worker      n -= sizeof(float);
157*4bdc9457SAndroid Build Coastguard Worker    } while (n != 0);
158*4bdc9457SAndroid Build Coastguard Worker  $elif BATCH_TILE == 2:
159*4bdc9457SAndroid Build Coastguard Worker    if XNN_UNLIKELY(n != 0) {
160*4bdc9457SAndroid Build Coastguard Worker      float vx = *x;
161*4bdc9457SAndroid Build Coastguard Worker
162*4bdc9457SAndroid Build Coastguard Worker      $if WASM:
163*4bdc9457SAndroid Build Coastguard Worker        const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
164*4bdc9457SAndroid Build Coastguard Worker      $else:
165*4bdc9457SAndroid Build Coastguard Worker        const float vz = vx * vprescale;
166*4bdc9457SAndroid Build Coastguard Worker
167*4bdc9457SAndroid Build Coastguard Worker      float vn = vz * vlog2e + vmagic_bias;
168*4bdc9457SAndroid Build Coastguard Worker      float vs = uint32_as_float(float_as_uint32(vn) << 23);
169*4bdc9457SAndroid Build Coastguard Worker      vn -= vmagic_bias;
170*4bdc9457SAndroid Build Coastguard Worker
171*4bdc9457SAndroid Build Coastguard Worker      float vt = vn * vminus_ln2_hi + vz;
172*4bdc9457SAndroid Build Coastguard Worker      vt = vn * vminus_ln2_lo + vt;
173*4bdc9457SAndroid Build Coastguard Worker
174*4bdc9457SAndroid Build Coastguard Worker      $if not WASM:
175*4bdc9457SAndroid Build Coastguard Worker        if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
176*4bdc9457SAndroid Build Coastguard Worker          vs = 0.0f;
177*4bdc9457SAndroid Build Coastguard Worker          vt = 0.0f;
178*4bdc9457SAndroid Build Coastguard Worker        }
179*4bdc9457SAndroid Build Coastguard Worker
180*4bdc9457SAndroid Build Coastguard Worker      float vp = vc6 * vt + vc5;
181*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vc4;
182*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vc3;
183*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vc2;
184*4bdc9457SAndroid Build Coastguard Worker      vp *= vt;
185*4bdc9457SAndroid Build Coastguard Worker
186*4bdc9457SAndroid Build Coastguard Worker      vt *= vs;
187*4bdc9457SAndroid Build Coastguard Worker      vs -= vone;
188*4bdc9457SAndroid Build Coastguard Worker      vp = vp * vt + vt;
189*4bdc9457SAndroid Build Coastguard Worker      const float ve = (vp + vs) * valpha;
190*4bdc9457SAndroid Build Coastguard Worker
191*4bdc9457SAndroid Build Coastguard Worker      $if WASM:
192*4bdc9457SAndroid Build Coastguard Worker        float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
193*4bdc9457SAndroid Build Coastguard Worker        vy += __builtin_wasm_min_f32(ve, 0.0f);
194*4bdc9457SAndroid Build Coastguard Worker      $else:
195*4bdc9457SAndroid Build Coastguard Worker        float vy = vx * vbeta;
196*4bdc9457SAndroid Build Coastguard Worker        if XNN_UNPREDICTABLE(vx < 0.0f) {
197*4bdc9457SAndroid Build Coastguard Worker          vy = ve;
198*4bdc9457SAndroid Build Coastguard Worker        }
199*4bdc9457SAndroid Build Coastguard Worker
200*4bdc9457SAndroid Build Coastguard Worker      *y = vy;
201*4bdc9457SAndroid Build Coastguard Worker    }
202*4bdc9457SAndroid Build Coastguard Worker  $else:
203*4bdc9457SAndroid Build Coastguard Worker    if XNN_UNLIKELY(n != 0) {
204*4bdc9457SAndroid Build Coastguard Worker      do {
205*4bdc9457SAndroid Build Coastguard Worker        float vx = *x++;
206*4bdc9457SAndroid Build Coastguard Worker
207*4bdc9457SAndroid Build Coastguard Worker        $if WASM:
208*4bdc9457SAndroid Build Coastguard Worker          const float vz = __builtin_wasm_min_f32(__builtin_wasm_max_f32(vx * vprescale, vsat_cutoff), 0.0f);
209*4bdc9457SAndroid Build Coastguard Worker        $else:
210*4bdc9457SAndroid Build Coastguard Worker          const float vz = vx * vprescale;
211*4bdc9457SAndroid Build Coastguard Worker
212*4bdc9457SAndroid Build Coastguard Worker        float vn = vz * vlog2e + vmagic_bias;
213*4bdc9457SAndroid Build Coastguard Worker        float vs = uint32_as_float(float_as_uint32(vn) << 23);
214*4bdc9457SAndroid Build Coastguard Worker        vn -= vmagic_bias;
215*4bdc9457SAndroid Build Coastguard Worker
216*4bdc9457SAndroid Build Coastguard Worker        float vt = vn * vminus_ln2_hi + vz;
217*4bdc9457SAndroid Build Coastguard Worker        vt = vn * vminus_ln2_lo + vt;
218*4bdc9457SAndroid Build Coastguard Worker
219*4bdc9457SAndroid Build Coastguard Worker        $if not WASM:
220*4bdc9457SAndroid Build Coastguard Worker          if XNN_UNPREDICTABLE(vz <= vsat_cutoff) {
221*4bdc9457SAndroid Build Coastguard Worker            vs = 0.0f;
222*4bdc9457SAndroid Build Coastguard Worker            vt = 0.0f;
223*4bdc9457SAndroid Build Coastguard Worker          }
224*4bdc9457SAndroid Build Coastguard Worker
225*4bdc9457SAndroid Build Coastguard Worker        float vp = vc6 * vt + vc5;
226*4bdc9457SAndroid Build Coastguard Worker        vp = vp * vt + vc4;
227*4bdc9457SAndroid Build Coastguard Worker        vp = vp * vt + vc3;
228*4bdc9457SAndroid Build Coastguard Worker        vp = vp * vt + vc2;
229*4bdc9457SAndroid Build Coastguard Worker        vp *= vt;
230*4bdc9457SAndroid Build Coastguard Worker
231*4bdc9457SAndroid Build Coastguard Worker        vt *= vs;
232*4bdc9457SAndroid Build Coastguard Worker        vs -= vone;
233*4bdc9457SAndroid Build Coastguard Worker        vp = vp * vt + vt;
234*4bdc9457SAndroid Build Coastguard Worker        const float ve = (vp + vs) * valpha;
235*4bdc9457SAndroid Build Coastguard Worker
236*4bdc9457SAndroid Build Coastguard Worker        $if WASM:
237*4bdc9457SAndroid Build Coastguard Worker          float vy = __builtin_wasm_max_f32(vx * vbeta, 0.0f);
238*4bdc9457SAndroid Build Coastguard Worker          vy += __builtin_wasm_min_f32(ve, 0.0f);
239*4bdc9457SAndroid Build Coastguard Worker        $else:
240*4bdc9457SAndroid Build Coastguard Worker          float vy = vx * vbeta;
241*4bdc9457SAndroid Build Coastguard Worker          if XNN_UNPREDICTABLE(vx < 0.0f) {
242*4bdc9457SAndroid Build Coastguard Worker            vy = ve;
243*4bdc9457SAndroid Build Coastguard Worker          }
244*4bdc9457SAndroid Build Coastguard Worker
245*4bdc9457SAndroid Build Coastguard Worker        *y++ = vy;
246*4bdc9457SAndroid Build Coastguard Worker
247*4bdc9457SAndroid Build Coastguard Worker        n -= sizeof(float);
248*4bdc9457SAndroid Build Coastguard Worker      } while (n != 0);
249*4bdc9457SAndroid Build Coastguard Worker    }
250*4bdc9457SAndroid Build Coastguard Worker}
251