/* Copyright (c) 2018 David Rowe
                 2018 Mozilla
                 2008-2011 Octasic Inc.
                 2012-2017 Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/* NEON support for ARM machines */

#ifndef VEC_NEON_H
#define VEC_NEON_H

#include <arm_neon.h>
#include "os_support.h"

#if defined(__arm__) && !defined(__aarch64__)
/* Emulate vcvtnq_s32_f32() for ARMv7 Neon. */
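/* Note: this only approximates round-to-nearest. vcvtq_n_s32_f32(x, 8)
   converts to Q8 fixed point (truncating toward zero) and vrshrq_n_s32(.., 8)
   then drops the fractional bits with rounding, at the cost of 8 bits of
   usable input range. */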
static OPUS_INLINE int32x4_t vcvtnq_s32_f32(float32x4_t x) {
  return vrshrq_n_s32(vcvtq_n_s32_f32(x, 8), 8);
}

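/* vpaddq_s16() and vmull_high_s8() are AArch64-only intrinsics; rebuild them
   from the ARMv7 d-register forms. */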
static OPUS_INLINE int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
  return vcombine_s16(vpadd_s16(vget_low_s16(a), vget_high_s16(a)), vpadd_s16(vget_low_s16(b), vget_high_s16(b)));
}

static OPUS_INLINE int16x8_t vmull_high_s8(int8x16_t a, int8x16_t b) {
  return vmull_s8(vget_high_s8(a), vget_high_s8(b));
}
#endif

#ifdef __ARM_FEATURE_FMA
/* If we can, force the compiler to use an FMA instruction rather than break
   vmlaq_f32() into fmul/fadd. */
#define vmlaq_f32(a,b,c) vfmaq_f32(a,b,c)
#endif

#ifndef LPCNET_TEST
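/* Vectorized exp() approximation used by the activation functions below:
   exp(x) is rewritten as 2^(x/log(2)), the integer part of the exponent is
   built directly in the IEEE-754 exponent field (hence the +127 bias and the
   << 23 shift), and the fractional part is covered by a cubic polynomial. */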
static inline float32x4_t exp4_approx(float32x4_t x) {
  int32x4_t i;
  float32x4_t xf;

  x = vmaxq_f32(vminq_f32(x, vdupq_n_f32(88.f)), vdupq_n_f32(-88.f));

  /* express exp(x) as exp2(x/log(2)), add 127 for the exponent later */
  x = vmlaq_f32(vdupq_n_f32(127.f), x, vdupq_n_f32(1.44269504f));

  /* split into integer and fractional parts */
  i = vcvtq_s32_f32(x);
  xf = vcvtq_f32_s32(i);
  x = vsubq_f32(x, xf);

  float32x4_t K0 = vdupq_n_f32(0.99992522f);
  float32x4_t K1 = vdupq_n_f32(0.69583354f);
  float32x4_t K2 = vdupq_n_f32(0.22606716f);
  float32x4_t K3 = vdupq_n_f32(0.078024523f);
  float32x4_t Y = vmlaq_f32(K0, x, vmlaq_f32(K1, x, vmlaq_f32(K2, K3, x)));

  /* compute 2^i */
  float32x4_t exponent = vreinterpretq_f32_s32(vshlq_n_s32(i, 23));

  Y = vmulq_f32(Y, exponent);
  return Y;
}

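/* tanh() approximation as an odd rational function: the numerator is x times
   a quadratic in x^2 and the denominator another quadratic in x^2. The
   division uses vrecpeq_f32(), a coarse reciprocal estimate with no
   Newton-Raphson refinement, and the result is clamped to [-1, 1]. */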
static inline float32x4_t tanh4_approx(float32x4_t X)
{
  const float32x4_t N0 = vdupq_n_f32(952.52801514f);
  const float32x4_t N1 = vdupq_n_f32(96.39235687f);
  const float32x4_t N2 = vdupq_n_f32(0.60863042f);
  const float32x4_t D0 = vdupq_n_f32(952.72399902f);
  const float32x4_t D1 = vdupq_n_f32(413.36801147f);
  const float32x4_t D2 = vdupq_n_f32(11.88600922f);
  const float32x4_t max_out = vdupq_n_f32(1.f);
  const float32x4_t min_out = vdupq_n_f32(-1.f);
  float32x4_t X2, num, den;
  X2 = vmulq_f32(X, X);
  num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
  den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
  num = vmulq_f32(num, X);
  den = vrecpeq_f32(den);
  num = vmulq_f32(num, den);
  return vmaxq_f32(min_out, vminq_f32(max_out, num));
}

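/* sigmoid() approximation with the same rational structure as tanh4_approx(),
   offset by 0.5 and clamped to [0, 1]. The coefficients correspond to the
   tanh ones rescaled through the identity sigmoid(x) = 0.5 + 0.5*tanh(x/2). */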
static inline float32x4_t sigmoid4_approx(float32x4_t X)
{
  const float32x4_t N0 = vdupq_n_f32(238.13200378f);
  const float32x4_t N1 = vdupq_n_f32(6.02452230f);
  const float32x4_t N2 = vdupq_n_f32(0.00950985f);
  const float32x4_t D0 = vdupq_n_f32(952.72399902f);
  const float32x4_t D1 = vdupq_n_f32(103.34200287f);
  const float32x4_t D2 = vdupq_n_f32(0.74287558f);
  const float32x4_t half = vdupq_n_f32(0.5f);
  const float32x4_t max_out = vdupq_n_f32(1.f);
  const float32x4_t min_out = vdupq_n_f32(0.f);
  float32x4_t X2, num, den;
  X2 = vmulq_f32(X, X);
  num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
  den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
  num = vmulq_f32(num, X);
  den = vrecpeq_f32(den);
  num = vmlaq_f32(half, num, den);
  return vmaxq_f32(min_out, vminq_f32(max_out, num));
}

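/* Scalar wrappers: broadcast the argument to all four lanes, run the vector
   approximation and return lane 0, so the scalar and vector code paths give
   identical results. */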
static inline float lpcnet_exp(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = exp4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

static inline float tanh_approx(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = tanh4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

static inline float sigmoid_approx(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = sigmoid4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

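/* Element-wise exp() over N values: blocks of 4 use exp4_approx() and the
   scalar tail uses lpcnet_exp(). Note that only the (unnormalized)
   exponentials are computed here; any softmax normalization has to happen in
   the caller. */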
static inline void softmax(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        float32x4_t X, Y;
        X = vld1q_f32(&x[i]);
        Y = exp4_approx(X);
        vst1q_f32(&y[i], Y);
    }
    for (;i<N;i++)
        y[i] = lpcnet_exp(x[i]);
}

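/* Vectorized tanh() and sigmoid() over N values. Blocks of 4 use the rational
   approximations above; the scalar tails are built on lpcnet_exp() instead,
   so the last few values can differ very slightly from the vector path. */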
static inline void vec_tanh(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        float32x4_t X, Y;
        X = vld1q_f32(&x[i]);
        Y = tanh4_approx(X);
        vst1q_f32(&y[i], Y);
    }
    for (;i<N;i++)
    {
        float ex2;
        ex2 = lpcnet_exp(2*x[i]);
        y[i] = (ex2-1)/(ex2+1);
    }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
    int i;
    for (i=0;i<N-3;i+=4)
    {
        float32x4_t X, Y;
        X = vld1q_f32(&x[i]);
        Y = sigmoid4_approx(X);
        vst1q_f32(&y[i], Y);
    }
    for (;i<N;i++)
    {
        float ex;
        ex = lpcnet_exp(x[i]);
        y[i] = (ex)/(ex+1);
    }
}
#endif

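/* Dense matrix-vector product y = W*x with column-major weights; col_stride
   is the distance in floats between consecutive columns. Rows are processed
   16 at a time and the 16 partial sums stay in NEON registers for the whole
   pass over the columns, so rows must be a multiple of 16. */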
static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
    int i, j;
    for (i=0;i<rows;i+=16)
    {
        float * restrict y = &out[i];

        /* keep y[0..15] in registers for duration of inner loop */

        float32x4_t y0_3 = vdupq_n_f32(0);
        float32x4_t y4_7 = vdupq_n_f32(0);
        float32x4_t y8_11 = vdupq_n_f32(0);
        float32x4_t y12_15 = vdupq_n_f32(0);

        for (j=0;j<cols;j++)
        {
            const float * restrict w;
            float32x4_t wvec0_3, wvec4_7, wvec8_11, wvec12_15;
            float32x4_t xj;

            w = &weights[j*col_stride + i];
            wvec0_3 = vld1q_f32(&w[0]);
            wvec4_7 = vld1q_f32(&w[4]);
            wvec8_11 = vld1q_f32(&w[8]);
            wvec12_15 = vld1q_f32(&w[12]);

            xj = vld1q_dup_f32(&x[j]);

            y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
            y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
            y8_11 = vmlaq_f32(y8_11, wvec8_11, xj);
            y12_15 = vmlaq_f32(y12_15, wvec12_15, xj);
        }

        /* save y[0..15] back to memory */

        vst1q_f32(&y[0], y0_3);
        vst1q_f32(&y[4], y4_7);
        vst1q_f32(&y[8], y8_11);
        vst1q_f32(&y[12], y12_15);
    }
}

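/* Same scheme as sgemv16x1(), but with an 8-row block; rows must be a
   multiple of 8. */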
static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
    int i, j;
    for (i=0;i<rows;i+=8)
    {
        float * restrict y = &out[i];

        /* keep y[0..7] in registers for duration of inner loop */

        float32x4_t y0_3 = vdupq_n_f32(0);
        float32x4_t y4_7 = vdupq_n_f32(0);

        for (j=0;j<cols;j++)
        {
            const float * restrict w;
            float32x4_t wvec0_3, wvec4_7;
            float32x4_t xj;

            w = &weights[j*col_stride + i];
            wvec0_3 = vld1q_f32(&w[0]);
            wvec4_7 = vld1q_f32(&w[4]);

            xj = vld1q_dup_f32(&x[j]);

            y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
            y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
        }

        /* save y[0..7] back to memory */

        vst1q_f32(&y[0], y0_3);
        vst1q_f32(&y[4], y4_7);
    }
}

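/* Dispatch on the row count: multiples of 16 and of 8 use the blocked kernels
   above, anything else falls back to a plain scalar loop. */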
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
   else {
      int i, j;
      for (i=0;i<rows;i++)
      {
         out[i] = 0;
         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
      }
   }
}

/* Temporarily use unoptimized version */
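/* Block-sparse matrix-vector product. For each block of 8 output rows, idx
   first holds the number of active 4-column blocks, then the input position
   of each block; w stores the matching 8x4 weight blocks contiguously
   (32 floats per block). */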
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{
   int i, j;
   OPUS_CLEAR(out, rows);
   for (i=0;i<rows;i+=8)
   {
      int cols;
      cols = *idx++;
      for (j=0;j<cols;j++)
      {
         int pos;
         float * restrict y;
         float xj0, xj1, xj2, xj3;
         pos = (*idx++);
         xj0 = x[pos+0];
         xj1 = x[pos+1];
         xj2 = x[pos+2];
         xj3 = x[pos+3];
         y = &out[i];
         y[0] += w[0]*xj0;
         y[1] += w[1]*xj0;
         y[2] += w[2]*xj0;
         y[3] += w[3]*xj0;
         y[4] += w[4]*xj0;
         y[5] += w[5]*xj0;
         y[6] += w[6]*xj0;
         y[7] += w[7]*xj0;

         y[0] += w[8]*xj1;
         y[1] += w[9]*xj1;
         y[2] += w[10]*xj1;
         y[3] += w[11]*xj1;
         y[4] += w[12]*xj1;
         y[5] += w[13]*xj1;
         y[6] += w[14]*xj1;
         y[7] += w[15]*xj1;

         y[0] += w[16]*xj2;
         y[1] += w[17]*xj2;
         y[2] += w[18]*xj2;
         y[3] += w[19]*xj2;
         y[4] += w[20]*xj2;
         y[5] += w[21]*xj2;
         y[6] += w[22]*xj2;
         y[7] += w[23]*xj2;

         y[0] += w[24]*xj3;
         y[1] += w[25]*xj3;
         y[2] += w[26]*xj3;
         y[3] += w[27]*xj3;
         y[4] += w[28]*xj3;
         y[5] += w[29]*xj3;
         y[6] += w[30]*xj3;
         y[7] += w[31]*xj3;
         w += 32;
      }
   }
}

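/* Quantization scale factors. SCALE/SCALE_1 are not used in this file itself;
   presumably they combine a 128 weight scale with the 127 activation scale
   used by cgemv8x4() below, for the code that calls these kernels. */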
#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)

#define MAX_INPUTS 2048
#define MAX_OUTPUTS 8192

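/* 4-way signed 8-bit dot product: use the SDOT instruction when the Armv8.2
   dot-product extension is available, otherwise emulate it with widening
   multiplies and pairwise additions. */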
#if __ARM_FEATURE_DOTPROD
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b) {
  return vdotq_s32(acc, a, b);
}
#else
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b)
{
  return vpadalq_s16(acc, vpaddq_s16(vmull_s8(vget_low_s8(a), vget_low_s8(b)),  vmull_high_s8(a, b)));
}
#endif

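/* 8-bit matrix-vector product. The float inputs are scaled by 127, rounded
   and packed into int8 (in x_int); the weights are int8, laid out as 8x4
   blocks (32 bytes covering 8 rows and 4 columns). Accumulation is done in
   int32 with vdotprod() and the result is converted back to float with a
   per-row scale[] factor. */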
static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      int32x4_t acc0, acc1;
      int32x4_t acc2, acc3;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      acc2 = vdupq_n_s32(0);
      acc3 = vdupq_n_s32(0);
      j=0;
      for (;j<cols-4;j+=8)
      {
         int8x16_t vw0, vw1, vw2, vw3, vx0, vx1;
         vx0 = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx0);
         acc1 = vdotprod(acc1, vw1, vx0);
         vx1 = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j+4]);
         vw2 = vld1q_s8(&w[32]);
         vw3 = vld1q_s8(&w[48]);
         acc2 = vdotprod(acc2, vw2, vx1);
         acc3 = vdotprod(acc3, vw3, vx1);
         w += 64;
      }
      acc0 = vaddq_s32(acc0, acc2);
      acc1 = vaddq_s32(acc1, acc3);
      for (;j<cols;j+=4)
      {
         int8x16_t vw0, vw1, vx;
         vx = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}

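/* Sparse counterpart of cgemv8x4(): idx uses the same layout as in
   sparse_sgemv8x4() (per 8-row block, a count of active 4-column blocks
   followed by their input positions), with int8 weights and per-row output
   scales. */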
static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      int colblocks;
      int32x4_t acc0, acc1;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      colblocks = *idx++;
      for (j=0;j<colblocks;j++)
      {
         int pos;
         int8x16_t vw0, vw1, vx;
         pos = (*idx++);
         vx = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[pos]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}


#endif