/* Copyright (c) 2018 Mozilla
                 2012-2017 Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
  AVX implementation of vector operations, compile with -mavx
  AVX2/FMA implementation of vector operations, compile with -mavx2 -mfma
*/

#ifndef VEC_AVX_H
#define VEC_AVX_H

#include <immintrin.h>
#include <math.h>
#include "celt/x86/x86cpu.h"

#define MAX_INPUTS (2048)

#define USE_SU_BIAS

#ifndef __SSE4_1__
/* No SSE4.1 _mm_floor_ps(): approximate floor(x) by rounding x-0.5 to nearest. */
static inline __m128 mm_floor_ps(__m128 x) {
   __m128 half = _mm_set1_ps(0.5);
   return _mm_cvtepi32_ps(_mm_cvtps_epi32(_mm_sub_ps(x, half)));
}
#undef _mm_floor_ps
#define _mm_floor_ps(x) mm_floor_ps(x)
#endif


/* If we don't have AVX available, emulate what we need with SSE up to 4.1. */
#ifndef __AVX__

typedef struct {
   __m128 lo;
   __m128 hi;
} mm256_emu;
#define __m256 mm256_emu

static inline mm256_emu mm256_loadu_ps(const float *src) {
   mm256_emu ret;
   ret.lo = _mm_loadu_ps(&src[0]);
   ret.hi = _mm_loadu_ps(&src[4]);
   return ret;
}
#define _mm256_loadu_ps(src) mm256_loadu_ps(src)


static inline void mm256_storeu_ps(float *dst, mm256_emu src) {
   _mm_storeu_ps(dst, src.lo);
   _mm_storeu_ps(&dst[4], src.hi);
}
#define _mm256_storeu_ps(dst, src) mm256_storeu_ps(dst, src)


static inline mm256_emu mm256_setzero_ps(void) {
   mm256_emu ret;
   ret.lo = _mm_setzero_ps();
   ret.hi = ret.lo;
   return ret;
}
#define _mm256_setzero_ps mm256_setzero_ps

static inline mm256_emu mm256_broadcast_ss(const float *x) {
   mm256_emu ret;
   ret.lo = _mm_set1_ps(*x);
   ret.hi = ret.lo;
   return ret;
}
#define _mm256_broadcast_ss(x) mm256_broadcast_ss(x)

static inline mm256_emu mm256_set1_ps(float x) {
   mm256_emu ret;
   ret.lo = _mm_set1_ps(x);
   ret.hi = ret.lo;
   return ret;
}
#define _mm256_set1_ps(x) mm256_set1_ps(x)



static inline mm256_emu mm256_mul_ps(mm256_emu a, mm256_emu b) {
   mm256_emu ret;
   ret.lo = _mm_mul_ps(a.lo, b.lo);
   ret.hi = _mm_mul_ps(a.hi, b.hi);
   return ret;
}
#define _mm256_mul_ps(a,b) mm256_mul_ps(a,b)

static inline mm256_emu mm256_add_ps(mm256_emu a, mm256_emu b) {
   mm256_emu ret;
   ret.lo = _mm_add_ps(a.lo, b.lo);
   ret.hi = _mm_add_ps(a.hi, b.hi);
   return ret;
}
#define _mm256_add_ps(a,b) mm256_add_ps(a,b)


static inline mm256_emu mm256_max_ps(mm256_emu a, mm256_emu b) {
   mm256_emu ret;
   ret.lo = _mm_max_ps(a.lo, b.lo);
   ret.hi = _mm_max_ps(a.hi, b.hi);
   return ret;
}
#define _mm256_max_ps(a,b) mm256_max_ps(a,b)

static inline mm256_emu mm256_min_ps(mm256_emu a, mm256_emu b) {
   mm256_emu ret;
   ret.lo = _mm_min_ps(a.lo, b.lo);
   ret.hi = _mm_min_ps(a.hi, b.hi);
   return ret;
}
#define _mm256_min_ps(a,b) mm256_min_ps(a,b)

static inline mm256_emu mm256_rcp_ps(mm256_emu a) {
   mm256_emu ret;
   ret.lo = _mm_rcp_ps(a.lo);
   ret.hi = _mm_rcp_ps(a.hi);
   return ret;
}
#define _mm256_rcp_ps(a) mm256_rcp_ps(a)


static inline __m128 mm256_extractf128_ps(mm256_emu x, int i) {
   return (i==0) ? x.lo : x.hi;
}
#undef _mm256_extractf128_ps
#define _mm256_extractf128_ps(x,i) mm256_extractf128_ps(x,i)

static inline mm256_emu mm256_insertf128_ps(mm256_emu dst, __m128 src, int i) {
   if (i==0) dst.lo = src;
   else dst.hi = src;
   return dst;
}
#undef _mm256_insertf128_ps
#define _mm256_insertf128_ps(dst,src,i) mm256_insertf128_ps(dst,src,i)

#endif /* __AVX__ */
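
/* Illustration only (hypothetical helper, not used elsewhere in this header): because
   the emulation above redefines __m256 and the _mm256_* names it relies on, 8-wide
   code such as this compiles the same way whether or not real AVX is available. */
static inline void vec_avx_example_add8(float *out, const float *a, const float *b) {
   __m256 va = _mm256_loadu_ps(a);
   __m256 vb = _mm256_loadu_ps(b);
   _mm256_storeu_ps(out, _mm256_add_ps(va, vb));
}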



/* If we don't have AVX2 available, emulate what we need with SSE up to 4.1. */
#ifndef __AVX2__

typedef struct {
   __m128i lo;
   __m128i hi;
} mm256i_emu;
typedef __m256i real_m256i;
#define __m256i mm256i_emu

static inline mm256i_emu mm256_setzero_si256(void) {
   mm256i_emu ret;
   ret.lo = _mm_setzero_si128();
   ret.hi = ret.lo;
   return ret;
}
#define _mm256_setzero_si256 mm256_setzero_si256


static inline mm256i_emu mm256_loadu_si256(const mm256i_emu *src) {
   mm256i_emu ret;
   ret.lo = _mm_loadu_si128((const __m128i*)src);
   ret.hi = _mm_loadu_si128(&((const __m128i*)src)[1]);
   return ret;
}
#define _mm256_loadu_si256(src) mm256_loadu_si256(src)


static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
   _mm_storeu_si128((__m128i*)dst, src.lo);
   _mm_storeu_si128(&((__m128i*)dst)[1], src.hi);
}
#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)


static inline mm256i_emu mm256_broadcastd_epi32(__m128i x) {
   mm256i_emu ret;
   ret.hi = ret.lo = _mm_shuffle_epi32(x, 0);
   return ret;
}
#define _mm256_broadcastd_epi32(x) mm256_broadcastd_epi32(x)


static inline mm256i_emu mm256_set1_epi32(int x) {
   mm256i_emu ret;
   ret.lo = _mm_set1_epi32(x);
   ret.hi = ret.lo;
   return ret;
}
#define _mm256_set1_epi32(x) mm256_set1_epi32(x)

static inline mm256i_emu mm256_set1_epi16(int x) {
   mm256i_emu ret;
   ret.lo = _mm_set1_epi16(x);
   ret.hi = ret.lo;
   return ret;
}
#define _mm256_set1_epi16(x) mm256_set1_epi16(x)


static inline mm256i_emu mm256_add_epi32(mm256i_emu a, mm256i_emu b) {
   mm256i_emu ret;
   ret.lo = _mm_add_epi32(a.lo, b.lo);
   ret.hi = _mm_add_epi32(a.hi, b.hi);
   return ret;
}
#define _mm256_add_epi32(a,b) mm256_add_epi32(a,b)

static inline mm256i_emu mm256_madd_epi16(mm256i_emu a, mm256i_emu b) {
   mm256i_emu ret;
   ret.lo = _mm_madd_epi16(a.lo, b.lo);
   ret.hi = _mm_madd_epi16(a.hi, b.hi);
   return ret;
}
#define _mm256_madd_epi16(a,b) mm256_madd_epi16(a,b)

static inline mm256i_emu mm256_maddubs_epi16(mm256i_emu a, mm256i_emu b) {
   mm256i_emu ret;
   ret.lo = _mm_maddubs_epi16(a.lo, b.lo);
   ret.hi = _mm_maddubs_epi16(a.hi, b.hi);
   return ret;
}
#define _mm256_maddubs_epi16(a,b) mm256_maddubs_epi16(a,b)


/* Emulating the int<->float conversion functions is tricky because they take or return
   __m256i yet are defined in AVX (not AVX2), so we need a special case for builds
   where only AVX is available. */
#ifdef __AVX__

typedef union {
   mm256i_emu fake;
   real_m256i real;
} mm256_union;

static inline __m256 mm256_cvtepi32_ps(mm256i_emu a) {
   mm256_union src;
   src.fake = a;
   return _mm256_cvtepi32_ps(src.real);
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)

static inline mm256i_emu mm256_cvtps_epi32(__m256 a) {
   mm256_union ret;
   ret.real = _mm256_cvtps_epi32(a);
   return ret.fake;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)


#else

static inline mm256_emu mm256_cvtepi32_ps(mm256i_emu a) {
   mm256_emu ret;
   ret.lo = _mm_cvtepi32_ps(a.lo);
   ret.hi = _mm_cvtepi32_ps(a.hi);
   return ret;
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)

static inline mm256i_emu mm256_cvtps_epi32(mm256_emu a) {
   mm256i_emu ret;
   ret.lo = _mm_cvtps_epi32(a.lo);
   ret.hi = _mm_cvtps_epi32(a.hi);
   return ret;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)

#endif /* __AVX__ */


#endif /* __AVX2__ */

/* In case we don't have FMA, make it a mul and an add. */
#if !(defined(__FMA__) && defined(__AVX__))
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
#endif
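
/* Without FMA the multiply is rounded before the add instead of being fused into a
   single rounding; for the approximations below that difference should be negligible. */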

#ifdef __AVX2__
static inline __m256 exp8_approx(__m256 X)
{
   const __m256 K0 = _mm256_set1_ps(0.99992522f);
   const __m256 K1 = _mm256_set1_ps(0.69583354f);
   const __m256 K2 = _mm256_set1_ps(0.22606716f);
   const __m256 K3 = _mm256_set1_ps(0.078024523f);
   const __m256 log2_E = _mm256_set1_ps(1.44269504f);
   const __m256 max_in = _mm256_set1_ps(50.f);
   const __m256 min_in = _mm256_set1_ps(-50.f);
   __m256 XF, Y;
   __m256i I;
   X = _mm256_mul_ps(X, log2_E);
   X = _mm256_max_ps(min_in, _mm256_min_ps(max_in, X));
   XF = _mm256_floor_ps(X);
   I = _mm256_cvtps_epi32(XF);
   X = _mm256_sub_ps(X, XF);
   Y = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(K3, X, K2), X, K1), X, K0);
   I = _mm256_slli_epi32(I, 23);
   Y = _mm256_castsi256_ps(_mm256_add_epi32(I, _mm256_castps_si256(Y)));
   return Y;
}
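
/* The last two steps above scale the polynomial by 2^I by adding I directly to the
   IEEE-754 exponent field: exp(x) = 2^(x*log2(e)) = 2^I * 2^f, with I = floor(x*log2(e))
   and f in [0, 1).  A minimal scalar sketch of that trick (illustration only, not used
   by this header; no overflow/underflow handling, assumes opus_int32 from the opus
   headers already included here): */
static inline float exp_scale_by_pow2_sketch(float y, int i) {
   union { float f; opus_int32 n; } u;
   u.f = y;
   u.n += i << 23;   /* bump the exponent field: multiplies y by 2^i */
   return u.f;
}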

static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
   int i;
   __m256 const127 = _mm256_set1_ps(127.f);
   for (i=0;i<len;i+=8) {
      __m256 xf;
      __m256i xi;
      xf = _mm256_loadu_ps(&_x[i]);
      xf = _mm256_fmadd_ps(xf, const127, const127);
      xi = _mm256_cvtps_epi32(xf);
      xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
      xi = _mm256_permute4x64_epi64(xi, 0xD8);
      xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
      xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
      _mm256_storeu_si256 ((__m256i *)(void*)&x[i], xi);
   }
}
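
/* Note on the conversion above: each float, nominally in [-1, 1], maps to an unsigned
   byte as 127 + 127*x (essentially the same mapping as the scalar fallback further
   down).  Only the low 8 bytes of the final vector are meaningful, but a full 32-byte
   store is issued, so the destination buffer needs room past len (the callers in this
   header use MAX_INPUTS-byte scratch buffers). */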

#else
static inline __m128 exp4_approx(__m128 X)
{
   const __m128 K0 = _mm_set1_ps(0.99992522f);
   const __m128 K1 = _mm_set1_ps(0.69583354f);
   const __m128 K2 = _mm_set1_ps(0.22606716f);
   const __m128 K3 = _mm_set1_ps(0.078024523f);
   const __m128 log2_E = _mm_set1_ps(1.44269504f);
   const __m128 max_in = _mm_set1_ps(50.f);
   const __m128 min_in = _mm_set1_ps(-50.f);
   const __m128i mask = _mm_set1_epi32(0x7fffffff);
   __m128 XF, Y;
   __m128i I;
   X = _mm_mul_ps(X, log2_E);
   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
   XF = _mm_floor_ps(X);
   I = _mm_cvtps_epi32(XF);
   X = _mm_sub_ps(X, XF);
   Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
   I = _mm_slli_epi32(I, 23);
   Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
   return Y;
}
static inline __m256 exp8_approx(__m256 X)
{
   __m256 Y;
   __m128 Xhi, Xlo, Yhi, Ylo;
   Xhi = _mm256_extractf128_ps(X, 1);
   Xlo = _mm256_extractf128_ps(X, 0);
   Yhi = exp4_approx(Xhi);
   Ylo = exp4_approx(Xlo);
   Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
   Y = _mm256_insertf128_ps(Y, Ylo, 0);
   return Y;
}

static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
   int i;
   for (i=0;i<len;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
}

#endif


#ifdef __AVX__

/* Approximating tanh() using a Padé-like rational function:
   tanh(x) ~= x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
   subject to the +/- 1 bounds.
   The coefficients were determined by gradient descent trying to minimize
   the maximum deviation over the whole range (this is only possible because
   of the bounds). The max error is around 3e-4 and is dominated by the
   reciprocal approximation (the max error of the rational function is
   around 6e-5).
*/
static inline __m256 tanh8_approx(__m256 X)
{
   const __m256 N0 = _mm256_set1_ps(952.52801514f);
   const __m256 N1 = _mm256_set1_ps(96.39235687f);
   const __m256 N2 = _mm256_set1_ps(0.60863042f);
   const __m256 D0 = _mm256_set1_ps(952.72399902f);
   const __m256 D1 = _mm256_set1_ps(413.36801147f);
   const __m256 D2 = _mm256_set1_ps(11.88600922f);
   const __m256 max_out = _mm256_set1_ps(1.f);
   const __m256 min_out = _mm256_set1_ps(-1.f);
   __m256 X2, num, den;
   X2 = _mm256_mul_ps(X, X);
   num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm256_mul_ps(num, X);
   den = _mm256_rcp_ps(den);
   num = _mm256_mul_ps(num, den);
   return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}

/* Sigmoid approximation using a Padé-like rational function:
   1/(1+exp(-x)) ~= 0.5 + x * (N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4)
   subject to the [0, 1] bounds.
   The coefficients are directly derived by dividing the tanh() coefficients
   by powers of two to get the correct scaling. The max error is around 1.5e-4
   and is dominated by the reciprocal approximation (the max error of the
   rational function is around 3e-5).
*/
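/* Concretely (a sanity check on the scaling): using sigmoid(x) = 0.5 + 0.5*tanh(x/2)
   and substituting x/2 into the tanh() rational above gives N0/4, N1/16, N2/64 for the
   numerator and D0, D1/4, D2/16 for the denominator, which matches the constants
   below. */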
static inline __m256 sigmoid8_approx(__m256 X)
{
   const __m256 N0 = _mm256_set1_ps(238.13200378f);
   const __m256 N1 = _mm256_set1_ps(6.02452230f);
   const __m256 N2 = _mm256_set1_ps(0.00950985f);
   const __m256 D0 = _mm256_set1_ps(952.72399902f);
   const __m256 D1 = _mm256_set1_ps(103.34200287f);
   const __m256 D2 = _mm256_set1_ps(0.74287558f);
   const __m256 half = _mm256_set1_ps(0.5);
   const __m256 max_out = _mm256_set1_ps(1.f);
   const __m256 min_out = _mm256_set1_ps(0.f);
   __m256 X2, num, den;
   X2 = _mm256_mul_ps(X, X);
   num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm256_mul_ps(num, X);
   den = _mm256_rcp_ps(den);
   num = _mm256_fmadd_ps(num, den, half);
   return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}

static inline float tanh_approx(float x)
{
   float out[8];
   __m256 X, Y;
   X = _mm256_set1_ps(x);
   Y = tanh8_approx(X);
   _mm256_storeu_ps(out, Y);
   return out[0];
}

static inline float sigmoid_approx(float x)
{
   float out[8];
   __m256 X, Y;
   X = _mm256_set1_ps(x);
   Y = sigmoid8_approx(X);
   _mm256_storeu_ps(out, Y);
   return out[0];
}

#else

static inline __m128 tanh4_approx(__m128 X)
{
   const __m128 N0 = _mm_set1_ps(952.52801514f);
   const __m128 N1 = _mm_set1_ps(96.39235687f);
   const __m128 N2 = _mm_set1_ps(0.60863042f);
   const __m128 D0 = _mm_set1_ps(952.72399902f);
   const __m128 D1 = _mm_set1_ps(413.36801147f);
   const __m128 D2 = _mm_set1_ps(11.88600922f);
   const __m128 max_out = _mm_set1_ps(1.f);
   const __m128 min_out = _mm_set1_ps(-1.f);
   __m128 X2, num, den;
   X2 = _mm_mul_ps(X, X);
   num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm_mul_ps(num, X);
   den = _mm_rcp_ps(den);
   num = _mm_mul_ps(num, den);
   return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}

static inline __m128 sigmoid4_approx(__m128 X)
{
   const __m128 N0 = _mm_set1_ps(238.13200378f);
   const __m128 N1 = _mm_set1_ps(6.02452230f);
   const __m128 N2 = _mm_set1_ps(0.00950985f);
   const __m128 D0 = _mm_set1_ps(952.72399902f);
   const __m128 D1 = _mm_set1_ps(103.34200287f);
   const __m128 D2 = _mm_set1_ps(0.74287558f);
   const __m128 half = _mm_set1_ps(0.5);
   const __m128 max_out = _mm_set1_ps(1.f);
   const __m128 min_out = _mm_set1_ps(0.f);
   __m128 X2, num, den;
   X2 = _mm_mul_ps(X, X);
   num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
   den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
   num = _mm_mul_ps(num, X);
   den = _mm_rcp_ps(den);
   num = _mm_fmadd_ps(num, den, half);
   return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}

static inline float tanh_approx(float x)
{
   float out[4];
   __m128 X, Y;
   X = _mm_set1_ps(x);
   Y = tanh4_approx(X);
   _mm_storeu_ps(out, Y);
   return out[0];
}

static inline float sigmoid_approx(float x)
{
   float out[4];
   __m128 X, Y;
   X = _mm_set1_ps(x);
   Y = sigmoid4_approx(X);
   _mm_storeu_ps(out, Y);
   return out[0];
}

#endif

static inline float lpcnet_exp(float x)
{
   float out[8];
   __m256 X, Y;
   X = _mm256_set1_ps(x);
   Y = exp8_approx(X);
   _mm256_storeu_ps(out, Y);
   return out[0];
}

static inline void softmax(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-7;i+=8)
   {
      __m256 X, Y;
      X = _mm256_loadu_ps(&x[i]);
      Y = exp8_approx(X);
      _mm256_storeu_ps(&y[i], Y);
   }
   for (;i<N;i++)
      y[i] = lpcnet_exp(x[i]);
}
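
/* Note: softmax() above only fills y[i] with exp(x[i]); normalizing by the sum of the
   exponentials is left to the caller.  A minimal usage sketch with a hypothetical
   caller-side buffer p[]:
      softmax(p, logits, N);
      for (sum=0, i=0; i<N; i++) sum += p[i];
      for (i=0; i<N; i++) p[i] /= sum;
*/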

#ifdef __AVX__
static inline void vec_tanh(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-7;i+=8)
   {
      __m256 X, Y;
      X = _mm256_loadu_ps(&x[i]);
      Y = tanh8_approx(X);
      _mm256_storeu_ps(&y[i], Y);
   }
   for (;i<N;i++)
   {
      y[i] = tanh_approx(x[i]);
   }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-7;i+=8)
   {
      __m256 X, Y;
      X = _mm256_loadu_ps(&x[i]);
      Y = sigmoid8_approx(X);
      _mm256_storeu_ps(&y[i], Y);
   }
   for (;i<N;i++)
   {
      y[i] = sigmoid_approx(x[i]);
   }
}
#else
static inline void vec_tanh(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-3;i+=4)
   {
      __m128 X, Y;
      X = _mm_loadu_ps(&x[i]);
      Y = tanh4_approx(X);
      _mm_storeu_ps(&y[i], Y);
   }
   for (;i<N;i++)
   {
      y[i] = tanh_approx(x[i]);
   }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-3;i+=4)
   {
      __m128 X, Y;
      X = _mm_loadu_ps(&x[i]);
      Y = sigmoid4_approx(X);
      _mm_storeu_ps(&y[i], Y);
   }
   for (;i<N;i++)
   {
      y[i] = sigmoid_approx(x[i]);
   }
}

#endif

#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)

#define opus_mm256_dpbusds_epi32(src, a, b) _mm256_dpbusds_epi32(src, a, b)

#elif defined(__AVX2__)

static inline __m256i opus_mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) {
   __m256i ones, tmp;
   ones = _mm256_set1_epi16(1);
   tmp = _mm256_maddubs_epi16(a, b);
   tmp = _mm256_madd_epi16(tmp, ones);
   return _mm256_add_epi32(src, tmp);
}

#elif defined(__SSSE3__)

static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
   mm256i_emu ones, tmp;
   ones = _mm256_set1_epi16(1);
   tmp = _mm256_maddubs_epi16(a, b);
   tmp = _mm256_madd_epi16(tmp, ones);
   return _mm256_add_epi32(src, tmp);
}

#elif defined(__SSE2__)

static inline __m128i mm_dpbusds_epi32(__m128i src, __m128i a, __m128i b) {
   __m128i ah, al, bh, bl, tmp;
   ah = _mm_srli_epi16(a, 8);
   bh = _mm_srai_epi16(b, 8);
   al = _mm_srli_epi16(_mm_slli_epi16(a, 8), 8);
   bl = _mm_srai_epi16(_mm_slli_epi16(b, 8), 8);
   tmp = _mm_add_epi32(_mm_madd_epi16(ah, bh), _mm_madd_epi16(al, bl));
   return _mm_add_epi32(src, tmp);
}

static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
   mm256i_emu res;
   res.hi = mm_dpbusds_epi32(src.hi, a.hi, b.hi);
   res.lo = mm_dpbusds_epi32(src.lo, a.lo, b.lo);
   return res;
}


#else

#error "No optimizations in vec_avx.h. This should never happen. "
#endif
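
/* The opus_mm256_dpbusds_epi32() selected above mimics the VNNI vpdpbusds operation:
   for each 32-bit lane, multiply four unsigned bytes of a with the corresponding four
   signed bytes of b, sum the four products, and add the result to src.  Note that the
   maddubs-based fallbacks saturate intermediate 16-bit sums, so they can differ from
   real VNNI on extreme inputs. */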

static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j;
   i=0;
   for (;i<rows-15;i+=16)
   {
      float *y;
      __m256 vy0, vy8;
      y = &out[i];
      vy0 = _mm256_setzero_ps();
      vy8 = _mm256_setzero_ps();
      for (j=0;j<cols;j++)
      {
         __m256 vxj;
         __m256 vw;
         vxj = _mm256_broadcast_ss(&x[j]);

         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
      }
      _mm256_storeu_ps (&y[0], vy0);
      _mm256_storeu_ps (&y[8], vy8);
   }
   for (;i<rows-7;i+=8)
   {
      float *y;
      __m256 vy0;
      y = &out[i];
      vy0 = _mm256_setzero_ps();
      for (j=0;j<cols;j++)
      {
         __m256 vxj;
         __m256 vw;
         vxj = _mm256_broadcast_ss(&x[j]);

         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
      }
      _mm256_storeu_ps (&y[0], vy0);
   }
   for (;i<rows-3;i+=4)
   {
      float *y;
      __m128 vy0;
      y = &out[i];
      vy0 = _mm_setzero_ps();
      for (j=0;j<cols;j++)
      {
         __m128 vxj;
         __m128 vw;
         vxj = _mm_set1_ps(x[j]);

         vw = _mm_loadu_ps(&weights[j*col_stride + i]);
         vy0 = _mm_fmadd_ps(vw, vxj, vy0);
      }
      _mm_storeu_ps (&y[0], vy0);
   }
   for (;i<rows;i++)
   {
      out[i] = 0;
      for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
   }
}
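
/* Usage note for sgemv(): weights are laid out column-major, with
   weights[j*col_stride + i] holding the coefficient for output row i and input column
   j, so the result is out[i] = sum_j weights[j*col_stride + i] * x[j] (see the scalar
   tail loop above). */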

static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
{
   int i, j;
   for (i=0;i<rows;i+=8)
   {
      float *y;
      int cols;
      __m256 vy0;
      y = &out[i];
      vy0 = _mm256_setzero_ps();
      cols = *idx++;
      for (j=0;j<cols;j++)
      {
         int id;
         __m256 vxj;
         __m256 vw;
         id = *idx++;
         vxj = _mm256_broadcast_ss(&x[id]);
         vw = _mm256_loadu_ps(&weights[0]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         vxj = _mm256_broadcast_ss(&x[id+1]);
         vw = _mm256_loadu_ps(&weights[8]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         vxj = _mm256_broadcast_ss(&x[id+2]);
         vw = _mm256_loadu_ps(&weights[16]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         vxj = _mm256_broadcast_ss(&x[id+3]);
         vw = _mm256_loadu_ps(&weights[24]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

         weights += 32;
      }
      _mm256_storeu_ps (&y[0], vy0);
   }
}
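
/* Block-sparse format used by sparse_sgemv8x4(): for every group of 8 output rows, idx
   first gives the number of active 8x4 blocks, then the starting input index of each
   block; weights stores the matching 8x4 blocks back to back (32 floats per block,
   column by column), which is what the 4-step inner loop above consumes. */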

static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   unsigned char x[MAX_INPUTS];
   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
   vector_ps_to_epi8(x, _x, cols);
   for (i=0;i<rows;i+=8)
   {
      int colblocks;
      __m256i vy0;
      __m256 vout;
      colblocks = *idx++;
      vy0 = _mm256_setzero_si256();
      j=0;
#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
      for (;j<colblocks-3;j+=4)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
#endif
      for (;j<colblocks;j++)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
      vout = _mm256_cvtepi32_ps(vy0);
      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
      _mm256_storeu_ps(&_out[i], vout);
   }
}
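
/* Quantized variant: sparse_cgemv8x4() uses the same block-sparse idx/weight layout as
   sparse_sgemv8x4(), but the input is first quantized to unsigned bytes around 127 and
   the weights are signed 8-bit, so each 8x4 block is accumulated with the dpbusds dot
   product above and the integer result is turned back into float with the per-row
   scale[i] factor. */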
static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   unsigned char x[MAX_INPUTS];
   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
   vector_ps_to_epi8(x, _x, cols);
   for (i=0;i<rows;i+=8)
   {
      __m256i vy0;
      __m256 vout;
      vy0 = _mm256_setzero_si256();
      j=0;
#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
      for (;j<cols-12;j+=16)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+4]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+8]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+12]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
#endif
      for (;j<cols;j+=4)
      {
         __m256i vxj;
         __m256i vw;
         vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
         vw = _mm256_loadu_si256((const __m256i *)(void*)w);
         vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
         w += 32;
      }
      vout = _mm256_cvtepi32_ps(vy0);
      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
      _mm256_storeu_ps(&_out[i], vout);
   }
}
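
/* cgemv8x4() is the dense counterpart: weights are stored as consecutive 8x4 int8
   blocks covering all columns, so it assumes cols is a multiple of 4 and no larger
   than MAX_INPUTS (the size of the x[] scratch buffer above). */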

#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)
#define USE_SU_BIAS


#endif /*VEC_AVX_H*/
