/* Copyright (c) 2018 David Rowe
                 2018 Mozilla
                 2008-2011 Octasic Inc.
                 2012-2017 Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/* NEON support for ARM machines */

#ifndef VEC_NEON_H
#define VEC_NEON_H

#include <arm_neon.h>
#include "os_support.h"

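/* The intrinsics below (vcvtnq_s32_f32, vpaddq_s16, vmull_high_s8) are only
   provided on AArch64, so on 32-bit ARMv7 NEON they are emulated here. */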
#if defined(__arm__) && !defined(__aarch64__)
/* Emulate vcvtnq_s32_f32() for ARMv7 Neon. */
static OPUS_INLINE int32x4_t vcvtnq_s32_f32(float32x4_t x) {
   return vrshrq_n_s32(vcvtq_n_s32_f32(x, 8), 8);
}

static OPUS_INLINE int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
   return vcombine_s16(vpadd_s16(vget_low_s16(a), vget_high_s16(a)), vpadd_s16(vget_low_s16(b), vget_high_s16(b)));
}

static OPUS_INLINE int16x8_t vmull_high_s8(int8x16_t a, int8x16_t b) {
   return vmull_s8(vget_high_s8(a), vget_high_s8(b));
}
#endif

#ifdef __ARM_FEATURE_FMA
/* If we can, force the compiler to use an FMA instruction rather than break
   vmlaq_f32() into fmul/fadd. */
#define vmlaq_f32(a,b,c) vfmaq_f32(a,b,c)
#endif

#ifndef LPCNET_TEST
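/* Vectorized approximation of exp() for four floats at a time.
   exp(x) is evaluated as 2^(x*log2(e)): the input is clamped, the integer part
   of the rescaled value is shifted into the IEEE-754 exponent field (bias 127,
   bit 23), and the fractional part goes through a small polynomial. This is a
   fast approximation intended for the activation functions below, not a
   substitute for libm exp(). */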
static inline float32x4_t exp4_approx(float32x4_t x) {
   int32x4_t i;
   float32x4_t xf;

   x = vmaxq_f32(vminq_f32(x, vdupq_n_f32(88.f)), vdupq_n_f32(-88.f));

   /* express exp(x) as exp2(x/log(2)), add 127 for the exponent later */
   x = vmlaq_f32(vdupq_n_f32(127.f), x, vdupq_n_f32(1.44269504f));

   /* split into integer and fractional parts */
   i = vcvtq_s32_f32(x);
   xf = vcvtq_f32_s32(i);
   x = vsubq_f32(x, xf);

   float32x4_t K0 = vdupq_n_f32(0.99992522f);
   float32x4_t K1 = vdupq_n_f32(0.69583354f);
   float32x4_t K2 = vdupq_n_f32(0.22606716f);
   float32x4_t K3 = vdupq_n_f32(0.078024523f);
   float32x4_t Y = vmlaq_f32(K0, x, vmlaq_f32(K1, x, vmlaq_f32(K2, K3, x)));

   /* compute 2^i */
   float32x4_t exponent = vreinterpretq_f32_s32(vshlq_n_s32(i, 23));

   Y = vmulq_f32(Y, exponent);
   return Y;
}

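/* Vectorized tanh() approximation: an odd rational fit num(x)/den(x) in x^2,
   with the division replaced by vrecpeq_f32()'s low-precision reciprocal
   estimate and the result clamped to [-1, 1]. */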
static inline float32x4_t tanh4_approx(float32x4_t X)
{
   const float32x4_t N0 = vdupq_n_f32(952.52801514f);
   const float32x4_t N1 = vdupq_n_f32(96.39235687f);
   const float32x4_t N2 = vdupq_n_f32(0.60863042f);
   const float32x4_t D0 = vdupq_n_f32(952.72399902f);
   const float32x4_t D1 = vdupq_n_f32(413.36801147f);
   const float32x4_t D2 = vdupq_n_f32(11.88600922f);
   const float32x4_t max_out = vdupq_n_f32(1.f);
   const float32x4_t min_out = vdupq_n_f32(-1.f);
   float32x4_t X2, num, den;
   X2 = vmulq_f32(X, X);
   num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
   den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
   num = vmulq_f32(num, X);
   den = vrecpeq_f32(den);
   num = vmulq_f32(num, den);
   return vmaxq_f32(min_out, vminq_f32(max_out, num));
}

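/* Vectorized sigmoid() approximation: the same rational form as
   tanh4_approx(), offset by 0.5 and clamped to [0, 1]. */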
static inline float32x4_t sigmoid4_approx(float32x4_t X)
{
   const float32x4_t N0 = vdupq_n_f32(238.13200378f);
   const float32x4_t N1 = vdupq_n_f32(6.02452230f);
   const float32x4_t N2 = vdupq_n_f32(0.00950985f);
   const float32x4_t D0 = vdupq_n_f32(952.72399902f);
   const float32x4_t D1 = vdupq_n_f32(103.34200287f);
   const float32x4_t D2 = vdupq_n_f32(0.74287558f);
   const float32x4_t half = vdupq_n_f32(0.5f);
   const float32x4_t max_out = vdupq_n_f32(1.f);
   const float32x4_t min_out = vdupq_n_f32(0.f);
   float32x4_t X2, num, den;
   X2 = vmulq_f32(X, X);
   num = vmlaq_f32(N0, X2, vmlaq_f32(N1, N2, X2));
   den = vmlaq_f32(D0, X2, vmlaq_f32(D1, D2, X2));
   num = vmulq_f32(num, X);
   den = vrecpeq_f32(den);
   num = vmlaq_f32(half, num, den);
   return vmaxq_f32(min_out, vminq_f32(max_out, num));
}

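/* Scalar wrappers around the SIMD approximations: broadcast the argument,
   run the 4-wide kernel, and return lane 0. These handle the scalar tails of
   the vectorized loops below. */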
static inline float lpcnet_exp(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = exp4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

static inline float tanh_approx(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = tanh4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

static inline float sigmoid_approx(float x)
{
   float out[4];
   float32x4_t X, Y;
   X = vdupq_n_f32(x);
   Y = sigmoid4_approx(X);
   vst1q_f32(out, Y);
   return out[0];
}

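/* Element-wise vectorized activations over length-N arrays, four elements per
   iteration with a scalar tail. Note that softmax() only exponentiates each
   element; normalization by the sum is left to the caller. */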
static inline void softmax(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-3;i+=4)
   {
      float32x4_t X, Y;
      X = vld1q_f32(&x[i]);
      Y = exp4_approx(X);
      vst1q_f32(&y[i], Y);
   }
   for (;i<N;i++)
      y[i] = lpcnet_exp(x[i]);
}

static inline void vec_tanh(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-3;i+=4)
   {
      float32x4_t X, Y;
      X = vld1q_f32(&x[i]);
      Y = tanh4_approx(X);
      vst1q_f32(&y[i], Y);
   }
   for (;i<N;i++)
   {
      float ex2;
      ex2 = lpcnet_exp(2*x[i]);
      y[i] = (ex2-1)/(ex2+1);
   }
}

static inline void vec_sigmoid(float *y, const float *x, int N)
{
   int i;
   for (i=0;i<N-3;i+=4)
   {
      float32x4_t X, Y;
      X = vld1q_f32(&x[i]);
      Y = sigmoid4_approx(X);
      vst1q_f32(&y[i], Y);
   }
   for (;i<N;i++)
   {
      float ex;
      ex = lpcnet_exp(x[i]);
      y[i] = (ex)/(ex+1);
   }
}
#endif

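/* Dense matrix-vector product out = weights * x with column-major weights
   (column j starts at weights[j*col_stride]). Sixteen output rows are
   processed per outer iteration so their partial sums stay in NEON registers;
   rows must be a multiple of 16. */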
static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j;
   for (i=0;i<rows;i+=16)
   {
      float * restrict y = &out[i];

      /* keep y[0..15] in registers for duration of inner loop */

      float32x4_t y0_3 = vdupq_n_f32(0);
      float32x4_t y4_7 = vdupq_n_f32(0);
      float32x4_t y8_11 = vdupq_n_f32(0);
      float32x4_t y12_15 = vdupq_n_f32(0);

      for (j=0;j<cols;j++)
      {
         const float * restrict w;
         float32x4_t wvec0_3, wvec4_7, wvec8_11, wvec12_15;
         float32x4_t xj;

         w = &weights[j*col_stride + i];
         wvec0_3 = vld1q_f32(&w[0]);
         wvec4_7 = vld1q_f32(&w[4]);
         wvec8_11 = vld1q_f32(&w[8]);
         wvec12_15 = vld1q_f32(&w[12]);

         xj = vld1q_dup_f32(&x[j]);

         y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
         y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
         y8_11 = vmlaq_f32(y8_11, wvec8_11, xj);
         y12_15 = vmlaq_f32(y12_15, wvec12_15, xj);
      }

      /* save y[0..15] back to memory */

      vst1q_f32(&y[0], y0_3);
      vst1q_f32(&y[4], y4_7);
      vst1q_f32(&y[8], y8_11);
      vst1q_f32(&y[12], y12_15);

   }
}

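/* Same as sgemv16x1(), but with 8 output rows per outer iteration; rows must
   be a multiple of 8. */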
static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j;
   for (i=0;i<rows;i+=8)
   {
      float * restrict y = &out[i];

      /* keep y[0..7] in registers for duration of inner loop */

      float32x4_t y0_3 = vdupq_n_f32(0);
      float32x4_t y4_7 = vdupq_n_f32(0);

      for (j=0;j<cols;j++)
      {
         const float * restrict w;
         float32x4_t wvec0_3, wvec4_7;
         float32x4_t xj;

         w = &weights[j*col_stride + i];
         wvec0_3 = vld1q_f32(&w[0]);
         wvec4_7 = vld1q_f32(&w[4]);

         xj = vld1q_dup_f32(&x[j]);

         y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
         y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
      }

      /* save y[0..7] back to memory */

      vst1q_f32(&y[0], y0_3);
      vst1q_f32(&y[4], y4_7);
   }
}

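/* Generic dense GEMV: dispatch to the 16-row or 8-row NEON kernel when the row
   count allows it, otherwise fall back to a plain scalar loop. */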
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
   else {
      int i, j;
      for (i=0;i<rows;i++)
      {
         out[i] = 0;
         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
      }
   }
}

/* Temporarily use unoptimized version */
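/* Block-sparse float GEMV on 8x4 blocks. For each group of 8 output rows, the
   idx stream holds the number of non-zero 4-column blocks followed by the
   starting column of each block; w stores the matching 8x4 weight blocks
   contiguously (32 floats per block, one column of 8 rows at a time). */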
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{
   int i, j;
   OPUS_CLEAR(out, rows);
   for (i=0;i<rows;i+=8)
   {
      int cols;
      cols = *idx++;
      for (j=0;j<cols;j++)
      {
         int pos;
         float * restrict y;
         float xj0, xj1, xj2, xj3;
         pos = (*idx++);
         xj0 = x[pos+0];
         xj1 = x[pos+1];
         xj2 = x[pos+2];
         xj3 = x[pos+3];
         y = &out[i];
         y[0] += w[0]*xj0;
         y[1] += w[1]*xj0;
         y[2] += w[2]*xj0;
         y[3] += w[3]*xj0;
         y[4] += w[4]*xj0;
         y[5] += w[5]*xj0;
         y[6] += w[6]*xj0;
         y[7] += w[7]*xj0;

         y[0] += w[8]*xj1;
         y[1] += w[9]*xj1;
         y[2] += w[10]*xj1;
         y[3] += w[11]*xj1;
         y[4] += w[12]*xj1;
         y[5] += w[13]*xj1;
         y[6] += w[14]*xj1;
         y[7] += w[15]*xj1;

         y[0] += w[16]*xj2;
         y[1] += w[17]*xj2;
         y[2] += w[18]*xj2;
         y[3] += w[19]*xj2;
         y[4] += w[20]*xj2;
         y[5] += w[21]*xj2;
         y[6] += w[22]*xj2;
         y[7] += w[23]*xj2;

         y[0] += w[24]*xj3;
         y[1] += w[25]*xj3;
         y[2] += w[26]*xj3;
         y[3] += w[27]*xj3;
         y[4] += w[28]*xj3;
         y[5] += w[29]*xj3;
         y[6] += w[30]*xj3;
         y[7] += w[31]*xj3;
         w += 32;
      }
   }
}


#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)

#define MAX_INPUTS 2048
#define MAX_OUTPUTS 8192

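/* 8-bit dot-product helper: accumulates, into each 32-bit lane of acc, the sum
   of the four products of the corresponding int8 lanes of a and b. Uses the
   SDOT instruction when the dot-product extension is available, otherwise an
   emulation built from vmull_s8/vpaddq_s16/vpadalq_s16. */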
#if __ARM_FEATURE_DOTPROD
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b) {
   return vdotq_s32(acc, a, b);
}
#else
static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b)
{
   return vpadalq_s16(acc, vpaddq_s16(vmull_s8(vget_low_s8(a), vget_low_s8(b)), vmull_high_s8(a, b)));
}
#endif

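/* Quantized dense GEMV on 8-bit weights. The float input (assumed to lie
   roughly in [-1, 1]) is quantized to int8 as round(x*127), products are
   accumulated in int32 via vdotprod(), and each group of 8 outputs is
   converted back to float and multiplied by its per-row scale[]. rows and cols
   are assumed to be multiples of 8, with cols at most MAX_INPUTS. */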
static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      int32x4_t acc0, acc1;
      int32x4_t acc2, acc3;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      acc2 = vdupq_n_s32(0);
      acc3 = vdupq_n_s32(0);
      j=0;
      for (;j<cols-4;j+=8)
      {
         int8x16_t vw0, vw1, vw2, vw3, vx0, vx1;
         vx0 = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx0);
         acc1 = vdotprod(acc1, vw1, vx0);
         vx1 = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j+4]);
         vw2 = vld1q_s8(&w[32]);
         vw3 = vld1q_s8(&w[48]);
         acc2 = vdotprod(acc2, vw2, vx1);
         acc3 = vdotprod(acc3, vw3, vx1);
         w += 64;
      }
      acc0 = vaddq_s32(acc0, acc2);
      acc1 = vaddq_s32(acc1, acc3);
      for (;j<cols;j+=4)
      {
         int8x16_t vw0, vw1, vx;
         vx = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[j]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}

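/* Block-sparse version of cgemv8x4(): for every group of 8 output rows, the
   idx stream gives the number of non-zero 8x4 blocks followed by the starting
   column of each block, with w holding 32 int8 weights per block. Input
   quantization and output scaling work the same way as in cgemv8x4(). */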
static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
   int i, j;
   opus_int32 x_int[MAX_INPUTS/4];
   opus_int8 *x = (opus_int8*) x_int;
   const float32x4_t const127 = vdupq_n_f32(127.);
   for (i=0;i<cols;i+=8) {
      int32x4_t xi0, xi4;
      int16x8_t x_short;
      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
      vst1_s8(&x[i], vmovn_s16(x_short));
   }
   for (i=0;i<rows;i+=8)
   {
      int colblocks;
      int32x4_t acc0, acc1;
      acc0 = vdupq_n_s32(0);
      acc1 = vdupq_n_s32(0);
      colblocks = *idx++;
      for (j=0;j<colblocks;j++)
      {
         int pos;
         int8x16_t vw0, vw1, vx;
         pos = (*idx++);
         vx = (int8x16_t)vld1q_dup_s32((int*)(void*)&x[pos]);
         vw0 = vld1q_s8(w);
         vw1 = vld1q_s8(&w[16]);
         acc0 = vdotprod(acc0, vw0, vx);
         acc1 = vdotprod(acc1, vw1, vx);
         w += 32;
      }
      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
   }
}


#endif