xref: /aosp_15_r20/external/libopus/dnn/vec.h (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
/* Copyright (c) 2018 Mozilla
                 2008-2011 Octasic Inc.
                 2012-2017 Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
28 
29 #ifndef VEC_H
30 #define VEC_H
31 
32 #include "opus_types.h"
33 #include <math.h>
34 #include "arch.h"
35 #include "x86/x86_arch_macros.h"
36 
37 
38 #if defined(__AVX__) || defined(__SSE2__)
39 #include "vec_avx.h"
40 #elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) && !defined(DISABLE_NEON)
41 #include "vec_neon.h"
42 #else
43 
44 #include "os_support.h"
45 
/* Number of elements in the on-stack quantized-input buffers declared by the
   cgemv8x4() and sparse_cgemv8x4() kernels in this header. */
#define MAX_INPUTS (2048)

/* Defined only on this plain-C path, i.e. when neither vec_avx.h nor
   vec_neon.h was included above. */
#define NO_OPTIMIZATIONS
49 
/* Dense float matrix-vector product, handled in bands of 16 output rows.
   The weight for output row r and input column c is read from
   weights[c*col_stride + r].  out is first passed to OPUS_CLEAR (defined
   elsewhere; presumably zero-fills rows elements) and then accumulated one
   input column at a time.
   Every band updates 16 consecutive outputs, so rows is expected to be a
   multiple of 16; sgemv() only dispatches here in that case. */
static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int row0, col, k;
   OPUS_CLEAR(out, rows);
   for (row0=0;row0<rows;row0+=16)
   {
      float * restrict acc = &out[row0];
      for (col=0;col<cols;col++)
      {
         /* Start of the 16 weights of this column that belong to the band. */
         const float * restrict wcol = &weights[col*col_stride + row0];
         const float xc = x[col];
         for (k=0;k<16;k++)
            acc[k] += wcol[k]*xc;
      }
   }
}
83 
/* Dense float matrix-vector product, handled in bands of 8 output rows.
   Same weight addressing as sgemv16x1(): the weight for output row r and
   input column c is weights[c*col_stride + r].  out is handed to OPUS_CLEAR
   (defined elsewhere; presumably zero-fills) before accumulation.
   Every band updates 8 consecutive outputs, so rows is expected to be a
   multiple of 8; sgemv() only dispatches here in that case. */
static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int base, c, lane;
   OPUS_CLEAR(out, rows);
   for (base=0;base<rows;base+=8)
   {
      float * restrict dst = &out[base];
      for (c=0;c<cols;c++)
      {
         const float * restrict src = &weights[c*col_stride + base];
         const float xin = x[c];
         for (lane=0;lane<8;lane++)
            dst[lane] += src[lane]*xin;
      }
   }
}
109 
/* Dense float matrix-vector product: out[r] = sum over c of
   weights[c*col_stride + r] * x[c], for r in [0, rows) and c in [0, cols).
   Chooses the 16-row kernel when rows is a multiple of 16, the 8-row kernel
   when it is a multiple of 8, and a plain scalar loop otherwise. */
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int r, c;
   if ((rows&0xf) == 0)
   {
      sgemv16x1(out, weights, rows, cols, col_stride, x);
      return;
   }
   if ((rows&0x7) == 0)
   {
      sgemv8x1(out, weights, rows, cols, col_stride, x);
      return;
   }
   /* Generic fallback: one output at a time, accumulated directly in out[]. */
   for (r=0;r<rows;r++)
   {
      out[r] = 0;
      for (c=0;c<cols;c++)
         out[r] += weights[c*col_stride + r]*x[c];
   }
}
123 
/* Block-sparse float matrix-vector product using 8x4 blocks.
   For every band of 8 output rows, idx supplies a block count followed by
   that many input positions.  Each block consumes 32 consecutive weights
   from w, stored as 4 groups of 8: group c holds the weights applied to
   x[pos+c] for the 8 rows of the band.
   out is handed to OPUS_CLEAR (defined elsewhere; presumably zero-fills)
   before accumulation.  rows is expected to be a multiple of 8, since each
   band updates 8 consecutive outputs. */
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{
   int band, blk, c, r;
   OPUS_CLEAR(out, rows);
   for (band=0;band<rows;band+=8)
   {
      float * restrict y = &out[band];
      const int nb_blocks = *idx++;
      for (blk=0;blk<nb_blocks;blk++)
      {
         const int pos = *idx++;
         float xin[4];
         /* Fetch the four inputs of the block before touching the outputs. */
         xin[0] = x[pos+0];
         xin[1] = x[pos+1];
         xin[2] = x[pos+2];
         xin[3] = x[pos+3];
         for (c=0;c<4;c++)
         {
            const float *wc = &w[8*c];
            const float xc = xin[c];
            for (r=0;r<8;r++)
               y[r] += wc[r]*xc;
         }
         w += 32;
      }
   }
}
182 
183 #ifdef USE_SU_BIAS
sparse_cgemv8x4(float * out,const opus_int8 * w,const int * idx,const float * scale,int rows,int cols,const float * _x)184 static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
185 {
186    int i, j;
187    unsigned char x[MAX_INPUTS];
188    for (i=0;i<rows;i++) out[i] = 0;
189    for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
190    for (i=0;i<rows;i+=8)
191    {
192       int colblocks;
193       colblocks = *idx++;
194       for (j=0;j<colblocks;j++)
195       {
196          int pos;
197          float * restrict y;
198          int xj0, xj1, xj2, xj3;
199          pos = (*idx++);
200          xj0 = x[pos+0];
201          xj1 = x[pos+1];
202          xj2 = x[pos+2];
203          xj3 = x[pos+3];
204          y = &out[i];
205          y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
206          y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
207          y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
208          y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
209          y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
210          y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
211          y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
212          y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
213          w += 32;
214       }
215    }
216    for (i=0;i<rows;i++) out[i] *= scale[i];
217 }
cgemv8x4(float * out,const opus_int8 * w,const float * scale,int rows,int cols,const float * _x)218 static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
219 {
220    int i, j;
221    unsigned char x[MAX_INPUTS];
222    for (i=0;i<rows;i++) out[i] = 0;
223    for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
224    for (i=0;i<rows;i+=8)
225    {
226       for (j=0;j<cols;j+=4)
227       {
228          float *y;
229          float xj0, xj1, xj2, xj3;
230          xj0 = x[j+0];
231          xj1 = x[j+1];
232          xj2 = x[j+2];
233          xj3 = x[j+3];
234          y = &out[i];
235          y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
236          y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
237          y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
238          y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
239          y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
240          y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
241          y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
242          y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
243          w += 32;
244       }
245    }
246    for (i=0;i<rows;i++) out[i] *= scale[i];
247 }
248 #else
sparse_cgemv8x4(float * out,const opus_int8 * w,const int * idx,const float * scale,int rows,int cols,const float * _x)249 static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
250 {
251    int i, j;
252    opus_int8 x[MAX_INPUTS];
253    for (i=0;i<rows;i++) out[i] = 0;
254    for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
255    for (i=0;i<rows;i+=8)
256    {
257       int colblocks;
258       colblocks = *idx++;
259       for (j=0;j<colblocks;j++)
260       {
261          int pos;
262          float * restrict y;
263          int xj0, xj1, xj2, xj3;
264          pos = (*idx++);
265          xj0 = x[pos+0];
266          xj1 = x[pos+1];
267          xj2 = x[pos+2];
268          xj3 = x[pos+3];
269          y = &out[i];
270          y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
271          y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
272          y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
273          y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
274          y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
275          y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
276          y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
277          y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
278          w += 32;
279       }
280    }
281    for (i=0;i<rows;i++) out[i] *= scale[i];
282 }
cgemv8x4(float * out,const opus_int8 * w,const float * scale,int rows,int cols,const float * _x)283 static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
284 {
285    int i, j;
286    opus_int8 x[MAX_INPUTS];
287    for (i=0;i<rows;i++) out[i] = 0;
288    for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
289    for (i=0;i<rows;i+=8)
290    {
291       for (j=0;j<cols;j+=4)
292       {
293          float *y;
294          float xj0, xj1, xj2, xj3;
295          xj0 = x[j+0];
296          xj1 = x[j+1];
297          xj2 = x[j+2];
298          xj3 = x[j+3];
299          y = &out[i];
300          y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
301          y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
302          y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
303          y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
304          y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
305          y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
306          y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
307          y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
308          w += 32;
309       }
310    }
311    for (i=0;i<rows;i++) out[i] *= scale[i];
312 }
313 #endif
314 
315 /* No AVX2/FMA support */
316 #ifndef LPCNET_TEST
lpcnet_exp2(float x)317 static inline float lpcnet_exp2(float x)
318 {
319    int integer;
320    float frac;
321    union {
322       float f;
323       opus_uint32 i;
324    } res;
325    integer = floor(x);
326    if (integer < -50)
327       return 0;
328    frac = x-integer;
329    /* K0 = 1, K1 = log(2), K2 = 3-4*log(2), K3 = 3*log(2) - 2 */
330    res.f = 0.99992522f + frac * (0.69583354f
331            + frac * (0.22606716f + 0.078024523f*frac));
332    res.i = (res.i + (integer<<23)) & 0x7fffffff;
333    return res.f;
334 }
335 #define lpcnet_exp(x) lpcnet_exp2((x)*1.44269504f)
336 
337 #define fmadd(a, b, c) ((a)*(b)+(c))
tanh_approx(float x)338 static OPUS_INLINE float tanh_approx(float x)
339 {
340     const float N0 = 952.52801514f;
341     const float N1 = 96.39235687f;
342     const float N2 = 0.60863042f;
343     const float D0 = 952.72399902f;
344     const float D1 = 413.36801147f;
345     const float D2 = 11.88600922f;
346     float X2, num, den;
347     X2 = x*x;
348     num = fmadd(fmadd(N2, X2, N1), X2, N0);
349     den = fmadd(fmadd(D2, X2, D1), X2, D0);
350     num = num*x/den;
351     return MAX32(-1.f, MIN32(1.f, num));
352 }
353 
/* Approximate logistic sigmoid, derived from tanh_approx() through
   sigmoid(x) = 0.5 + 0.5*tanh(x/2). */
static inline float sigmoid_approx(float x)
{
   const float half_tanh = tanh_approx(.5f*x);
   return .5f + .5f*half_tanh;
}
358 
/* Writes lpcnet_exp(x[k]) to y[k] for each of the N entries.
   Note that no normalization is applied here: the outputs are the raw
   exponentials and are not divided by their sum. */
static inline void softmax(float *y, const float *x, int N)
{
    int k = 0;
    while (k < N)
    {
        y[k] = lpcnet_exp(x[k]);
        k++;
    }
}
365 
/* Element-wise tanh approximation: y[k] = tanh_approx(x[k]) for k in [0, N). */
static inline void vec_tanh(float *y, const float *x, int N)
{
    int k;
    for (k=0;k<N;k++)
        y[k] = tanh_approx(x[k]);
}
374 
/* Element-wise sigmoid approximation: y[k] = sigmoid_approx(x[k]) for
   k in [0, N). */
static inline void vec_sigmoid(float *y, const float *x, int N)
{
    int k;
    for (k=0;k<N;k++)
        y[k] = sigmoid_approx(x[k]);
}
383 #endif
384 
/* Scaling constant 128*127 (= 16256) as a float. */
#define SCALE (128.f*127.f)
/* Reciprocal of SCALE, 1/(128*127). */
#define SCALE_1 (1.f/128.f/127.f)
387 
388 #endif /*no optimizations*/
389 #endif /*VEC_H*/
390