xref: /aosp_15_r20/external/rnnoise/src/rnn.c (revision 1295d6828459cc82c3c29cc5d7d297215250a74b)
1*1295d682SXin Li /* Copyright (c) 2008-2011 Octasic Inc.
2*1295d682SXin Li                  2012-2017 Jean-Marc Valin */
3*1295d682SXin Li /*
4*1295d682SXin Li    Redistribution and use in source and binary forms, with or without
5*1295d682SXin Li    modification, are permitted provided that the following conditions
6*1295d682SXin Li    are met:
7*1295d682SXin Li 
8*1295d682SXin Li    - Redistributions of source code must retain the above copyright
9*1295d682SXin Li    notice, this list of conditions and the following disclaimer.
10*1295d682SXin Li 
11*1295d682SXin Li    - Redistributions in binary form must reproduce the above copyright
12*1295d682SXin Li    notice, this list of conditions and the following disclaimer in the
13*1295d682SXin Li    documentation and/or other materials provided with the distribution.
14*1295d682SXin Li 
15*1295d682SXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*1295d682SXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*1295d682SXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*1295d682SXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19*1295d682SXin Li    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*1295d682SXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*1295d682SXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*1295d682SXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*1295d682SXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*1295d682SXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*1295d682SXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*1295d682SXin Li */
27*1295d682SXin Li 
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "opus_types.h"
#include "common.h"
#include "arch.h"
#include "tansig_table.h"
#include "rnn.h"
#include "rnn_data.h"
40*1295d682SXin Li 
tansig_approx(float x)41*1295d682SXin Li static OPUS_INLINE float tansig_approx(float x)
42*1295d682SXin Li {
43*1295d682SXin Li     int i;
44*1295d682SXin Li     float y, dy;
45*1295d682SXin Li     float sign=1;
46*1295d682SXin Li     /* Tests are reversed to catch NaNs */
47*1295d682SXin Li     if (!(x<8))
48*1295d682SXin Li         return 1;
49*1295d682SXin Li     if (!(x>-8))
50*1295d682SXin Li         return -1;
51*1295d682SXin Li #ifndef FIXED_POINT
52*1295d682SXin Li     /* Another check in case of -ffast-math */
53*1295d682SXin Li     if (celt_isnan(x))
54*1295d682SXin Li        return 0;
55*1295d682SXin Li #endif
56*1295d682SXin Li     if (x<0)
57*1295d682SXin Li     {
58*1295d682SXin Li        x=-x;
59*1295d682SXin Li        sign=-1;
60*1295d682SXin Li     }
61*1295d682SXin Li     i = (int)floor(.5f+25*x);
62*1295d682SXin Li     x -= .04f*i;
63*1295d682SXin Li     y = tansig_table[i];
64*1295d682SXin Li     dy = 1-y*y;
65*1295d682SXin Li     y = y + x*dy*(1 - y*x);
66*1295d682SXin Li     return sign*y;
67*1295d682SXin Li }
68*1295d682SXin Li 
sigmoid_approx(float x)69*1295d682SXin Li static OPUS_INLINE float sigmoid_approx(float x)
70*1295d682SXin Li {
71*1295d682SXin Li    return .5 + .5*tansig_approx(.5*x);
72*1295d682SXin Li }
73*1295d682SXin Li 
relu(float x)74*1295d682SXin Li static OPUS_INLINE float relu(float x)
75*1295d682SXin Li {
76*1295d682SXin Li    return x < 0 ? 0 : x;
77*1295d682SXin Li }
78*1295d682SXin Li 
compute_dense(const DenseLayer * layer,float * output,const float * input)79*1295d682SXin Li void compute_dense(const DenseLayer *layer, float *output, const float *input)
80*1295d682SXin Li {
81*1295d682SXin Li    int i, j;
82*1295d682SXin Li    int N, M;
83*1295d682SXin Li    int stride;
84*1295d682SXin Li    M = layer->nb_inputs;
85*1295d682SXin Li    N = layer->nb_neurons;
86*1295d682SXin Li    stride = N;
87*1295d682SXin Li    for (i=0;i<N;i++)
88*1295d682SXin Li    {
89*1295d682SXin Li       /* Compute update gate. */
90*1295d682SXin Li       float sum = layer->bias[i];
91*1295d682SXin Li       for (j=0;j<M;j++)
92*1295d682SXin Li          sum += layer->input_weights[j*stride + i]*input[j];
93*1295d682SXin Li       output[i] = WEIGHTS_SCALE*sum;
94*1295d682SXin Li    }
95*1295d682SXin Li    if (layer->activation == ACTIVATION_SIGMOID) {
96*1295d682SXin Li       for (i=0;i<N;i++)
97*1295d682SXin Li          output[i] = sigmoid_approx(output[i]);
98*1295d682SXin Li    } else if (layer->activation == ACTIVATION_TANH) {
99*1295d682SXin Li       for (i=0;i<N;i++)
100*1295d682SXin Li          output[i] = tansig_approx(output[i]);
101*1295d682SXin Li    } else if (layer->activation == ACTIVATION_RELU) {
102*1295d682SXin Li       for (i=0;i<N;i++)
103*1295d682SXin Li          output[i] = relu(output[i]);
104*1295d682SXin Li    } else {
105*1295d682SXin Li      *(int*)0=0;
106*1295d682SXin Li    }
107*1295d682SXin Li }
108*1295d682SXin Li 
compute_gru(const GRULayer * gru,float * state,const float * input)109*1295d682SXin Li void compute_gru(const GRULayer *gru, float *state, const float *input)
110*1295d682SXin Li {
111*1295d682SXin Li    int i, j;
112*1295d682SXin Li    int N, M;
113*1295d682SXin Li    int stride;
114*1295d682SXin Li    float z[MAX_NEURONS];
115*1295d682SXin Li    float r[MAX_NEURONS];
116*1295d682SXin Li    float h[MAX_NEURONS];
117*1295d682SXin Li    M = gru->nb_inputs;
118*1295d682SXin Li    N = gru->nb_neurons;
119*1295d682SXin Li    stride = 3*N;
120*1295d682SXin Li    for (i=0;i<N;i++)
121*1295d682SXin Li    {
122*1295d682SXin Li       /* Compute update gate. */
123*1295d682SXin Li       float sum = gru->bias[i];
124*1295d682SXin Li       for (j=0;j<M;j++)
125*1295d682SXin Li          sum += gru->input_weights[j*stride + i]*input[j];
126*1295d682SXin Li       for (j=0;j<N;j++)
127*1295d682SXin Li          sum += gru->recurrent_weights[j*stride + i]*state[j];
128*1295d682SXin Li       z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
129*1295d682SXin Li    }
130*1295d682SXin Li    for (i=0;i<N;i++)
131*1295d682SXin Li    {
132*1295d682SXin Li       /* Compute reset gate. */
133*1295d682SXin Li       float sum = gru->bias[N + i];
134*1295d682SXin Li       for (j=0;j<M;j++)
135*1295d682SXin Li          sum += gru->input_weights[N + j*stride + i]*input[j];
136*1295d682SXin Li       for (j=0;j<N;j++)
137*1295d682SXin Li          sum += gru->recurrent_weights[N + j*stride + i]*state[j];
138*1295d682SXin Li       r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
139*1295d682SXin Li    }
140*1295d682SXin Li    for (i=0;i<N;i++)
141*1295d682SXin Li    {
142*1295d682SXin Li       /* Compute output. */
143*1295d682SXin Li       float sum = gru->bias[2*N + i];
144*1295d682SXin Li       for (j=0;j<M;j++)
145*1295d682SXin Li          sum += gru->input_weights[2*N + j*stride + i]*input[j];
146*1295d682SXin Li       for (j=0;j<N;j++)
147*1295d682SXin Li          sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
148*1295d682SXin Li       if (gru->activation == ACTIVATION_SIGMOID) sum = sigmoid_approx(WEIGHTS_SCALE*sum);
149*1295d682SXin Li       else if (gru->activation == ACTIVATION_TANH) sum = tansig_approx(WEIGHTS_SCALE*sum);
150*1295d682SXin Li       else if (gru->activation == ACTIVATION_RELU) sum = relu(WEIGHTS_SCALE*sum);
151*1295d682SXin Li       else *(int*)0=0;
152*1295d682SXin Li       h[i] = z[i]*state[i] + (1-z[i])*sum;
153*1295d682SXin Li    }
154*1295d682SXin Li    for (i=0;i<N;i++)
155*1295d682SXin Li       state[i] = h[i];
156*1295d682SXin Li }
157*1295d682SXin Li 
158*1295d682SXin Li #define INPUT_SIZE 42
159*1295d682SXin Li 
/* Run the full noise-suppression network for one frame.
 *
 * rnn   - network weights plus the three persistent GRU states
 * gains - receives the per-band gains from the denoise output layer
 * vad   - receives the voice-activity estimate
 * input - INPUT_SIZE (42) frame features
 *
 * NOTE(review): the concatenation buffers are MAX_NEURONS*3 floats; this
 * presumes input_dense_size + vad_gru_size + INPUT_SIZE (and the denoise
 * equivalent) fit within that -- verify against rnn_data.h.
 */
void compute_rnn(RNNState *rnn, float *gains, float *vad, const float *input) {
  int k;
  float dense_out[MAX_NEURONS];
  float noise_input[MAX_NEURONS*3];
  float denoise_input[MAX_NEURONS*3];

  /* Shared front end, then the VAD branch. */
  compute_dense(rnn->model->input_dense, dense_out, input);
  compute_gru(rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
  compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);

  /* Noise GRU sees [dense output | VAD state | raw features]. */
  {
    const int d = rnn->model->input_dense_size;
    const int v = rnn->model->vad_gru_size;
    for (k=0;k<d;k++) noise_input[k] = dense_out[k];
    for (k=0;k<v;k++) noise_input[d + k] = rnn->vad_gru_state[k];
    for (k=0;k<INPUT_SIZE;k++) noise_input[d + v + k] = input[k];
  }
  compute_gru(rnn->model->noise_gru, rnn->noise_gru_state, noise_input);

  /* Denoise GRU sees [VAD state | noise state | raw features]. */
  {
    const int v = rnn->model->vad_gru_size;
    const int g = rnn->model->noise_gru_size;
    for (k=0;k<v;k++) denoise_input[k] = rnn->vad_gru_state[k];
    for (k=0;k<g;k++) denoise_input[v + k] = rnn->noise_gru_state[k];
    for (k=0;k<INPUT_SIZE;k++) denoise_input[v + g + k] = input[k];
  }
  compute_gru(rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
  compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
}
179