1 /* Copyright (c) 2018 Mozilla
2 2008-2011 Octasic Inc.
3 2012-2017 Jean-Marc Valin */
4 /*
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions
7 are met:
8
9 - Redistributions of source code must retain the above copyright
10 notice, this list of conditions and the following disclaimer.
11
12 - Redistributions in binary form must reproduce the above copyright
13 notice, this list of conditions and the following disclaimer in the
14 documentation and/or other materials provided with the distribution.
15
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
20 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #ifndef VEC_H
30 #define VEC_H
31
32 #include "opus_types.h"
33 #include <math.h>
34 #include "arch.h"
35 #include "x86/x86_arch_macros.h"
36
37
38 #if defined(__AVX__) || defined(__SSE2__)
39 #include "vec_avx.h"
40 #elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) && !defined(DISABLE_NEON)
41 #include "vec_neon.h"
42 #else
43
44 #include "os_support.h"
45
/* Capacity of the on-stack buffers that hold the quantized input vector in
   cgemv8x4() and sparse_cgemv8x4(); those functions write cols entries, so
   cols must not exceed this value. */
#define MAX_INPUTS (2048)

/* Defined only on this branch, i.e. when neither vec_avx.h nor vec_neon.h was
   included and the plain C kernels below are used.
   NOTE(review): the consumers of this flag are outside this header. */
#define NO_OPTIMIZATIONS
49
/* Dense matrix-vector product computed 16 output rows at a time.
   For every i in [0, rows):
      out[i] = sum over j in [0, cols) of weights[j*col_stride + i] * x[j]
   so column j of the matrix starts at weights + j*col_stride.
   Each pass of the outer loop updates out[base..base+15], hence rows is
   expected to be a multiple of 16 (sgemv() only dispatches here in that
   case). */
static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int base, col, k;
   /* Results are accumulated in place, so start from an all-zero output. */
   for (k=0;k<rows;k++) out[k] = 0;
   for (base=0;base<rows;base+=16)
   {
      for (col=0;col<cols;col++)
      {
         const float * restrict wcol = &weights[col*col_stride + base];
         float * restrict acc = &out[base];
         const float xval = x[col];
         /* Add this column's contribution to the 16 rows of the group. */
         for (k=0;k<16;k++) acc[k] += wcol[k]*xval;
      }
   }
}
83
/* Dense matrix-vector product computed 8 output rows at a time.
   For every i in [0, rows):
      out[i] = sum over j in [0, cols) of weights[j*col_stride + i] * x[j]
   Each pass of the outer loop updates out[base..base+7], hence rows is
   expected to be a multiple of 8 (sgemv() only dispatches here in that
   case). */
static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int base, col, k;
   /* Results are accumulated in place, so start from an all-zero output. */
   for (k=0;k<rows;k++) out[k] = 0;
   for (base=0;base<rows;base+=8)
   {
      for (col=0;col<cols;col++)
      {
         const float * restrict wcol = &weights[col*col_stride + base];
         float * restrict acc = &out[base];
         const float xval = x[col];
         /* Add this column's contribution to the 8 rows of the group. */
         for (k=0;k<8;k++) acc[k] += wcol[k]*xval;
      }
   }
}
109
/* Dense matrix-vector product: out[i] = sum_j weights[j*col_stride + i]*x[j]
   for i in [0, rows), j in [0, cols).
   Picks the 16-row kernel when rows is a multiple of 16, the 8-row kernel
   when it is a multiple of 8, and a generic row-by-row loop otherwise. */
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int r, c;
   if ((rows&0xf) == 0)
   {
      sgemv16x1(out, weights, rows, cols, col_stride, x);
      return;
   }
   if ((rows&0x7) == 0)
   {
      sgemv8x1(out, weights, rows, cols, col_stride, x);
      return;
   }
   /* Generic path: any row count, one output at a time. */
   for (r=0;r<rows;r++)
   {
      out[r] = 0;
      for (c=0;c<cols;c++)
      {
         out[r] += weights[c*col_stride + r]*x[c];
      }
   }
}
123
/* Block-sparse float matrix-vector product using 8x4 blocks.
   The output is processed in groups of 8 rows. For each group, idx holds a
   block count followed by that many input positions; every block consumes
   x[pos..pos+3] and 32 consecutive weights laid out as four runs of 8
   (w[8*c + r] multiplies x[pos+c] and is added to row r of the group).
   w and idx are consumed sequentially across groups. rows is expected to be
   a multiple of 8 since each group updates 8 outputs. */
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{
   int base, blk, r;
   /* Results are accumulated in place, so start from an all-zero output. */
   for (r=0;r<rows;r++) out[r] = 0;
   for (base=0;base<rows;base+=8)
   {
      const int nblocks = *idx++;
      for (blk=0;blk<nblocks;blk++)
      {
         const int pos = *idx++;
         float * restrict y = &out[base];
         const float x0 = x[pos+0];
         const float x1 = x[pos+1];
         const float x2 = x[pos+2];
         const float x3 = x[pos+3];
         /* One pass per input column of the block, in column order. */
         for (r=0;r<8;r++) y[r] += w[r]*x0;
         for (r=0;r<8;r++) y[r] += w[8+r]*x1;
         for (r=0;r<8;r++) y[r] += w[16+r]*x2;
         for (r=0;r<8;r++) y[r] += w[24+r]*x3;
         w += 32;
      }
   }
}
182
183 #ifdef USE_SU_BIAS
sparse_cgemv8x4(float * out,const opus_int8 * w,const int * idx,const float * scale,int rows,int cols,const float * _x)184 static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
185 {
186 int i, j;
187 unsigned char x[MAX_INPUTS];
188 for (i=0;i<rows;i++) out[i] = 0;
189 for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
190 for (i=0;i<rows;i+=8)
191 {
192 int colblocks;
193 colblocks = *idx++;
194 for (j=0;j<colblocks;j++)
195 {
196 int pos;
197 float * restrict y;
198 int xj0, xj1, xj2, xj3;
199 pos = (*idx++);
200 xj0 = x[pos+0];
201 xj1 = x[pos+1];
202 xj2 = x[pos+2];
203 xj3 = x[pos+3];
204 y = &out[i];
205 y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
206 y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
207 y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
208 y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
209 y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
210 y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
211 y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
212 y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
213 w += 32;
214 }
215 }
216 for (i=0;i<rows;i++) out[i] *= scale[i];
217 }
cgemv8x4(float * out,const opus_int8 * w,const float * scale,int rows,int cols,const float * _x)218 static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
219 {
220 int i, j;
221 unsigned char x[MAX_INPUTS];
222 for (i=0;i<rows;i++) out[i] = 0;
223 for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
224 for (i=0;i<rows;i+=8)
225 {
226 for (j=0;j<cols;j+=4)
227 {
228 float *y;
229 float xj0, xj1, xj2, xj3;
230 xj0 = x[j+0];
231 xj1 = x[j+1];
232 xj2 = x[j+2];
233 xj3 = x[j+3];
234 y = &out[i];
235 y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
236 y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
237 y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
238 y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
239 y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
240 y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
241 y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
242 y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
243 w += 32;
244 }
245 }
246 for (i=0;i<rows;i++) out[i] *= scale[i];
247 }
248 #else
sparse_cgemv8x4(float * out,const opus_int8 * w,const int * idx,const float * scale,int rows,int cols,const float * _x)249 static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
250 {
251 int i, j;
252 opus_int8 x[MAX_INPUTS];
253 for (i=0;i<rows;i++) out[i] = 0;
254 for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
255 for (i=0;i<rows;i+=8)
256 {
257 int colblocks;
258 colblocks = *idx++;
259 for (j=0;j<colblocks;j++)
260 {
261 int pos;
262 float * restrict y;
263 int xj0, xj1, xj2, xj3;
264 pos = (*idx++);
265 xj0 = x[pos+0];
266 xj1 = x[pos+1];
267 xj2 = x[pos+2];
268 xj3 = x[pos+3];
269 y = &out[i];
270 y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
271 y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
272 y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
273 y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
274 y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
275 y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
276 y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
277 y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
278 w += 32;
279 }
280 }
281 for (i=0;i<rows;i++) out[i] *= scale[i];
282 }
cgemv8x4(float * out,const opus_int8 * w,const float * scale,int rows,int cols,const float * _x)283 static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
284 {
285 int i, j;
286 opus_int8 x[MAX_INPUTS];
287 for (i=0;i<rows;i++) out[i] = 0;
288 for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
289 for (i=0;i<rows;i+=8)
290 {
291 for (j=0;j<cols;j+=4)
292 {
293 float *y;
294 float xj0, xj1, xj2, xj3;
295 xj0 = x[j+0];
296 xj1 = x[j+1];
297 xj2 = x[j+2];
298 xj3 = x[j+3];
299 y = &out[i];
300 y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
301 y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
302 y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
303 y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
304 y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
305 y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
306 y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
307 y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
308 w += 32;
309 }
310 }
311 for (i=0;i<rows;i++) out[i] *= scale[i];
312 }
313 #endif
314
315 /* No AVX2/FMA support */
316 #ifndef LPCNET_TEST
lpcnet_exp2(float x)317 static inline float lpcnet_exp2(float x)
318 {
319 int integer;
320 float frac;
321 union {
322 float f;
323 opus_uint32 i;
324 } res;
325 integer = floor(x);
326 if (integer < -50)
327 return 0;
328 frac = x-integer;
329 /* K0 = 1, K1 = log(2), K2 = 3-4*log(2), K3 = 3*log(2) - 2 */
330 res.f = 0.99992522f + frac * (0.69583354f
331 + frac * (0.22606716f + 0.078024523f*frac));
332 res.i = (res.i + (integer<<23)) & 0x7fffffff;
333 return res.f;
334 }
335 #define lpcnet_exp(x) lpcnet_exp2((x)*1.44269504f)
336
337 #define fmadd(a, b, c) ((a)*(b)+(c))
tanh_approx(float x)338 static OPUS_INLINE float tanh_approx(float x)
339 {
340 const float N0 = 952.52801514f;
341 const float N1 = 96.39235687f;
342 const float N2 = 0.60863042f;
343 const float D0 = 952.72399902f;
344 const float D1 = 413.36801147f;
345 const float D2 = 11.88600922f;
346 float X2, num, den;
347 X2 = x*x;
348 num = fmadd(fmadd(N2, X2, N1), X2, N0);
349 den = fmadd(fmadd(D2, X2, D1), X2, D0);
350 num = num*x/den;
351 return MAX32(-1.f, MIN32(1.f, num));
352 }
353
/* Logistic sigmoid approximation built on tanh_approx() through the identity
   sigmoid(x) = 0.5 + 0.5*tanh(x/2). */
static inline float sigmoid_approx(float x)
{
   const float half_x = .5f*x;
   return .5f + .5f*tanh_approx(half_x);
}
358
/* Writes lpcnet_exp(x[i]) to y[i] for i in [0, N).
   Only the exponential is applied here; the outputs are not divided by their
   sum, so any normalization is left to the caller. */
static inline void softmax(float *y, const float *x, int N)
{
   int k = 0;
   while (k < N)
   {
      y[k] = lpcnet_exp(x[k]);
      k++;
   }
}
365
/* Element-wise tanh approximation: y[i] = tanh_approx(x[i]) for i in [0, N). */
static inline void vec_tanh(float *y, const float *x, int N)
{
   int k = 0;
   while (k < N)
   {
      y[k] = tanh_approx(x[k]);
      k++;
   }
}
374
/* Element-wise sigmoid approximation: y[i] = sigmoid_approx(x[i]) for
   i in [0, N). */
static inline void vec_sigmoid(float *y, const float *x, int N)
{
   int k = 0;
   while (k < N)
   {
      y[k] = sigmoid_approx(x[k]);
      k++;
   }
}
383 #endif
384
/* Constant 128*127 and its reciprocal 1/(128*127). Neither is used within
   this header.
   NOTE(review): presumably tied to the 8-bit quantization (the factor 127
   also appears in cgemv8x4()) -- confirm against the code that uses them. */
#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)
387
388 #endif /*no optimizations*/
389 #endif /*VEC_H*/
390