xref: /aosp_15_r20/external/rnnoise/src/denoise.c (revision 1295d6828459cc82c3c29cc5d7d297215250a74b)
1*1295d682SXin Li /* Copyright (c) 2018 Gregor Richards
2*1295d682SXin Li  * Copyright (c) 2017 Mozilla */
3*1295d682SXin Li /*
4*1295d682SXin Li    Redistribution and use in source and binary forms, with or without
5*1295d682SXin Li    modification, are permitted provided that the following conditions
6*1295d682SXin Li    are met:
7*1295d682SXin Li 
8*1295d682SXin Li    - Redistributions of source code must retain the above copyright
9*1295d682SXin Li    notice, this list of conditions and the following disclaimer.
10*1295d682SXin Li 
11*1295d682SXin Li    - Redistributions in binary form must reproduce the above copyright
12*1295d682SXin Li    notice, this list of conditions and the following disclaimer in the
13*1295d682SXin Li    documentation and/or other materials provided with the distribution.
14*1295d682SXin Li 
15*1295d682SXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*1295d682SXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*1295d682SXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*1295d682SXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19*1295d682SXin Li    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*1295d682SXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*1295d682SXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*1295d682SXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*1295d682SXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*1295d682SXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*1295d682SXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*1295d682SXin Li */
27*1295d682SXin Li 
28*1295d682SXin Li #ifdef HAVE_CONFIG_H
29*1295d682SXin Li #include "config.h"
30*1295d682SXin Li #endif
31*1295d682SXin Li 
32*1295d682SXin Li #include <stdlib.h>
33*1295d682SXin Li #include <string.h>
34*1295d682SXin Li #include <stdio.h>
35*1295d682SXin Li #include "kiss_fft.h"
36*1295d682SXin Li #include "common.h"
37*1295d682SXin Li #include <math.h>
38*1295d682SXin Li #include "rnnoise.h"
39*1295d682SXin Li #include "pitch.h"
40*1295d682SXin Li #include "arch.h"
41*1295d682SXin Li #include "rnn.h"
42*1295d682SXin Li #include "rnn_data.h"
43*1295d682SXin Li 
/* Frame geometry.  Each call consumes FRAME_SIZE new samples; analysis and
   synthesis work on a window of two frames, of which only the first
   FRAME_SIZE+1 frequency bins are kept (the rest are rebuilt by conjugate
   symmetry in inverse_transform()). */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
#define WINDOW_SIZE (2*FRAME_SIZE)
#define FREQ_SIZE (FRAME_SIZE + 1)

/* Pitch analysis: candidate period range (in samples), the length of the
   segment that is searched, and the history needed to cover both. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* Note: the argument is expanded twice; do not pass expressions with side effects. */
#define SQUARE(x) ((x)*(x))

/* Number of frequency bands the spectrum is reduced to. */
#define NB_BANDS 22

/* Number of past cepstra kept in DenoiseState.cepstral_mem, and number of
   leading coefficients for which temporal differences are computed. */
#define CEPS_MEM 8
#define NB_DELTA_CEPS 6

/* Length of the feature vector: NB_BANDS cepstral values, three groups of
   NB_DELTA_CEPS values, plus the pitch-period and spectral-variability
   scalars written by compute_frame_features(). */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)


/* Build with -DTRAINING=1 to compile the feature-extraction main() below. */
#if !defined(TRAINING)
#define TRAINING 0
#endif
67*1295d682SXin Li 
68*1295d682SXin Li 
69*1295d682SXin Li /* The built-in model, used if no file is given as input */
70*1295d682SXin Li extern const struct RNNModel rnnoise_model_orig;
71*1295d682SXin Li 
72*1295d682SXin Li 
73*1295d682SXin Li static const opus_int16 eband5ms[] = {
74*1295d682SXin Li /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
75*1295d682SXin Li   0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
76*1295d682SXin Li };
77*1295d682SXin Li 
78*1295d682SXin Li 
79*1295d682SXin Li typedef struct {
80*1295d682SXin Li   int init;
81*1295d682SXin Li   kiss_fft_state *kfft;
82*1295d682SXin Li   float half_window[FRAME_SIZE];
83*1295d682SXin Li   float dct_table[NB_BANDS*NB_BANDS];
84*1295d682SXin Li } CommonState;
85*1295d682SXin Li 
86*1295d682SXin Li struct DenoiseState {
87*1295d682SXin Li   float analysis_mem[FRAME_SIZE];
88*1295d682SXin Li   float cepstral_mem[CEPS_MEM][NB_BANDS];
89*1295d682SXin Li   int memid;
90*1295d682SXin Li   float synthesis_mem[FRAME_SIZE];
91*1295d682SXin Li   float pitch_buf[PITCH_BUF_SIZE];
92*1295d682SXin Li   float pitch_enh_buf[PITCH_BUF_SIZE];
93*1295d682SXin Li   float last_gain;
94*1295d682SXin Li   int last_period;
95*1295d682SXin Li   float mem_hp_x[2];
96*1295d682SXin Li   float lastg[NB_BANDS];
97*1295d682SXin Li   RNNState rnn;
98*1295d682SXin Li };
99*1295d682SXin Li 
compute_band_energy(float * bandE,const kiss_fft_cpx * X)100*1295d682SXin Li void compute_band_energy(float *bandE, const kiss_fft_cpx *X) {
101*1295d682SXin Li   int i;
102*1295d682SXin Li   float sum[NB_BANDS] = {0};
103*1295d682SXin Li   for (i=0;i<NB_BANDS-1;i++)
104*1295d682SXin Li   {
105*1295d682SXin Li     int j;
106*1295d682SXin Li     int band_size;
107*1295d682SXin Li     band_size = (eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;
108*1295d682SXin Li     for (j=0;j<band_size;j++) {
109*1295d682SXin Li       float tmp;
110*1295d682SXin Li       float frac = (float)j/band_size;
111*1295d682SXin Li       tmp = SQUARE(X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].r);
112*1295d682SXin Li       tmp += SQUARE(X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].i);
113*1295d682SXin Li       sum[i] += (1-frac)*tmp;
114*1295d682SXin Li       sum[i+1] += frac*tmp;
115*1295d682SXin Li     }
116*1295d682SXin Li   }
117*1295d682SXin Li   sum[0] *= 2;
118*1295d682SXin Li   sum[NB_BANDS-1] *= 2;
119*1295d682SXin Li   for (i=0;i<NB_BANDS;i++)
120*1295d682SXin Li   {
121*1295d682SXin Li     bandE[i] = sum[i];
122*1295d682SXin Li   }
123*1295d682SXin Li }
124*1295d682SXin Li 
compute_band_corr(float * bandE,const kiss_fft_cpx * X,const kiss_fft_cpx * P)125*1295d682SXin Li void compute_band_corr(float *bandE, const kiss_fft_cpx *X, const kiss_fft_cpx *P) {
126*1295d682SXin Li   int i;
127*1295d682SXin Li   float sum[NB_BANDS] = {0};
128*1295d682SXin Li   for (i=0;i<NB_BANDS-1;i++)
129*1295d682SXin Li   {
130*1295d682SXin Li     int j;
131*1295d682SXin Li     int band_size;
132*1295d682SXin Li     band_size = (eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;
133*1295d682SXin Li     for (j=0;j<band_size;j++) {
134*1295d682SXin Li       float tmp;
135*1295d682SXin Li       float frac = (float)j/band_size;
136*1295d682SXin Li       tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].r * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].r;
137*1295d682SXin Li       tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].i * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].i;
138*1295d682SXin Li       sum[i] += (1-frac)*tmp;
139*1295d682SXin Li       sum[i+1] += frac*tmp;
140*1295d682SXin Li     }
141*1295d682SXin Li   }
142*1295d682SXin Li   sum[0] *= 2;
143*1295d682SXin Li   sum[NB_BANDS-1] *= 2;
144*1295d682SXin Li   for (i=0;i<NB_BANDS;i++)
145*1295d682SXin Li   {
146*1295d682SXin Li     bandE[i] = sum[i];
147*1295d682SXin Li   }
148*1295d682SXin Li }
149*1295d682SXin Li 
interp_band_gain(float * g,const float * bandE)150*1295d682SXin Li void interp_band_gain(float *g, const float *bandE) {
151*1295d682SXin Li   int i;
152*1295d682SXin Li   memset(g, 0, FREQ_SIZE);
153*1295d682SXin Li   for (i=0;i<NB_BANDS-1;i++)
154*1295d682SXin Li   {
155*1295d682SXin Li     int j;
156*1295d682SXin Li     int band_size;
157*1295d682SXin Li     band_size = (eband5ms[i+1]-eband5ms[i])<<FRAME_SIZE_SHIFT;
158*1295d682SXin Li     for (j=0;j<band_size;j++) {
159*1295d682SXin Li       float frac = (float)j/band_size;
160*1295d682SXin Li       g[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j] = (1-frac)*bandE[i] + frac*bandE[i+1];
161*1295d682SXin Li     }
162*1295d682SXin Li   }
163*1295d682SXin Li }
164*1295d682SXin Li 
165*1295d682SXin Li 
166*1295d682SXin Li CommonState common;
167*1295d682SXin Li 
check_init()168*1295d682SXin Li static void check_init() {
169*1295d682SXin Li   int i;
170*1295d682SXin Li   if (common.init) return;
171*1295d682SXin Li   common.kfft = opus_fft_alloc_twiddles(2*FRAME_SIZE, NULL, NULL, NULL, 0);
172*1295d682SXin Li   for (i=0;i<FRAME_SIZE;i++)
173*1295d682SXin Li     common.half_window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
174*1295d682SXin Li   for (i=0;i<NB_BANDS;i++) {
175*1295d682SXin Li     int j;
176*1295d682SXin Li     for (j=0;j<NB_BANDS;j++) {
177*1295d682SXin Li       common.dct_table[i*NB_BANDS + j] = cos((i+.5)*j*M_PI/NB_BANDS);
178*1295d682SXin Li       if (j==0) common.dct_table[i*NB_BANDS + j] *= sqrt(.5);
179*1295d682SXin Li     }
180*1295d682SXin Li   }
181*1295d682SXin Li   common.init = 1;
182*1295d682SXin Li }
183*1295d682SXin Li 
dct(float * out,const float * in)184*1295d682SXin Li static void dct(float *out, const float *in) {
185*1295d682SXin Li   int i;
186*1295d682SXin Li   check_init();
187*1295d682SXin Li   for (i=0;i<NB_BANDS;i++) {
188*1295d682SXin Li     int j;
189*1295d682SXin Li     float sum = 0;
190*1295d682SXin Li     for (j=0;j<NB_BANDS;j++) {
191*1295d682SXin Li       sum += in[j] * common.dct_table[j*NB_BANDS + i];
192*1295d682SXin Li     }
193*1295d682SXin Li     out[i] = sum*sqrt(2./22);
194*1295d682SXin Li   }
195*1295d682SXin Li }
196*1295d682SXin Li 
#if 0
/* Disabled: transpose of dct(), reading common.dct_table row-wise.

   out: NB_BANDS output values (must not alias `in`).
   in:  NB_BANDS input coefficients. */
static void idct(float *out, const float *in) {
  int i;
  /* Normalisation tied to the transform size instead of a literal 22. */
  const double scale = sqrt(2./NB_BANDS);
  check_init();
  for (i=0;i<NB_BANDS;i++) {
    int j;
    float sum = 0;
    for (j=0;j<NB_BANDS;j++) {
      sum += in[j] * common.dct_table[i*NB_BANDS + j];
    }
    out[i] = sum*scale;
  }
}
#endif
211*1295d682SXin Li 
/* FFT of a real WINDOW_SIZE-sample signal.

   out: receives the first FREQ_SIZE bins of the transform.
   in:  WINDOW_SIZE real samples. */
static void forward_transform(kiss_fft_cpx *out, const float *in) {
  kiss_fft_cpx time_buf[WINDOW_SIZE];
  kiss_fft_cpx freq_buf[WINDOW_SIZE];
  int n;
  check_init();
  /* Promote the real input to complex with a zero imaginary part. */
  for (n=0;n<WINDOW_SIZE;n++) {
    time_buf[n].r = in[n];
    time_buf[n].i = 0;
  }
  opus_fft(common.kfft, time_buf, freq_buf, 0);
  /* Only the first FREQ_SIZE bins are returned to the caller. */
  memcpy(out, freq_buf, FREQ_SIZE*sizeof(freq_buf[0]));
}
226*1295d682SXin Li 
inverse_transform(float * out,const kiss_fft_cpx * in)227*1295d682SXin Li static void inverse_transform(float *out, const kiss_fft_cpx *in) {
228*1295d682SXin Li   int i;
229*1295d682SXin Li   kiss_fft_cpx x[WINDOW_SIZE];
230*1295d682SXin Li   kiss_fft_cpx y[WINDOW_SIZE];
231*1295d682SXin Li   check_init();
232*1295d682SXin Li   for (i=0;i<FREQ_SIZE;i++) {
233*1295d682SXin Li     x[i] = in[i];
234*1295d682SXin Li   }
235*1295d682SXin Li   for (;i<WINDOW_SIZE;i++) {
236*1295d682SXin Li     x[i].r = x[WINDOW_SIZE - i].r;
237*1295d682SXin Li     x[i].i = -x[WINDOW_SIZE - i].i;
238*1295d682SXin Li   }
239*1295d682SXin Li   opus_fft(common.kfft, x, y, 0);
240*1295d682SXin Li   /* output in reverse order for IFFT. */
241*1295d682SXin Li   out[0] = WINDOW_SIZE*y[0].r;
242*1295d682SXin Li   for (i=1;i<WINDOW_SIZE;i++) {
243*1295d682SXin Li     out[i] = WINDOW_SIZE*y[WINDOW_SIZE - i].r;
244*1295d682SXin Li   }
245*1295d682SXin Li }
246*1295d682SXin Li 
apply_window(float * x)247*1295d682SXin Li static void apply_window(float *x) {
248*1295d682SXin Li   int i;
249*1295d682SXin Li   check_init();
250*1295d682SXin Li   for (i=0;i<FRAME_SIZE;i++) {
251*1295d682SXin Li     x[i] *= common.half_window[i];
252*1295d682SXin Li     x[WINDOW_SIZE - 1 - i] *= common.half_window[i];
253*1295d682SXin Li   }
254*1295d682SXin Li }
255*1295d682SXin Li 
rnnoise_get_size()256*1295d682SXin Li int rnnoise_get_size() {
257*1295d682SXin Li   return sizeof(DenoiseState);
258*1295d682SXin Li }
259*1295d682SXin Li 
rnnoise_get_frame_size()260*1295d682SXin Li int rnnoise_get_frame_size() {
261*1295d682SXin Li   return FRAME_SIZE;
262*1295d682SXin Li }
263*1295d682SXin Li 
/* Initialise a caller-provided DenoiseState.

   st:    storage of at least rnnoise_get_size() bytes; it is zeroed first,
          so it must not already own GRU buffers.
   model: network to use, or NULL for the built-in rnnoise_model_orig.

   Returns 0 on success.  Returns -1 if a GRU state buffer could not be
   allocated; in that case nothing is left allocated and the three state
   pointers are NULL, so rnnoise_destroy()-style cleanup remains safe.
   NOTE(review): a model with a zero-sized GRU may also report failure on
   platforms where calloc(0, n) returns NULL. */
int rnnoise_init(DenoiseState *st, RNNModel *model) {
  memset(st, 0, sizeof(*st));
  if (model)
    st->rnn.model = model;
  else
    st->rnn.model = &rnnoise_model_orig;
  /* calloc(count, element_size): zero-filled state for each GRU layer. */
  st->rnn.vad_gru_state = calloc(st->rnn.model->vad_gru_size, sizeof(float));
  st->rnn.noise_gru_state = calloc(st->rnn.model->noise_gru_size, sizeof(float));
  st->rnn.denoise_gru_state = calloc(st->rnn.model->denoise_gru_size, sizeof(float));
  if (st->rnn.vad_gru_state == NULL || st->rnn.noise_gru_state == NULL
      || st->rnn.denoise_gru_state == NULL) {
    free(st->rnn.vad_gru_state);
    free(st->rnn.noise_gru_state);
    free(st->rnn.denoise_gru_state);
    st->rnn.vad_gru_state = NULL;
    st->rnn.noise_gru_state = NULL;
    st->rnn.denoise_gru_state = NULL;
    return -1;
  }
  return 0;
}
275*1295d682SXin Li 
rnnoise_create(RNNModel * model)276*1295d682SXin Li DenoiseState *rnnoise_create(RNNModel *model) {
277*1295d682SXin Li   DenoiseState *st;
278*1295d682SXin Li   st = malloc(rnnoise_get_size());
279*1295d682SXin Li   rnnoise_init(st, model);
280*1295d682SXin Li   return st;
281*1295d682SXin Li }
282*1295d682SXin Li 
/* Release a state obtained from rnnoise_create(), including its GRU buffers.
   Passing NULL is a no-op, like free(). */
void rnnoise_destroy(DenoiseState *st) {
  if (st == NULL) return;
  free(st->rnn.vad_gru_state);
  free(st->rnn.noise_gru_state);
  free(st->rnn.denoise_gru_state);
  free(st);
}
289*1295d682SXin Li 
#if TRAINING
/* Training-only low-pass simulation, re-drawn periodically by main():
   frame_analysis() zeroes spectrum bins from `lowpass` upwards, and main()
   marks bands above `band_lp` as "no target". */
int lowpass = FREQ_SIZE;  /* first bin to zero; FREQ_SIZE disables the filter */
int band_lp = NB_BANDS;   /* NB_BANDS means no band is excluded */
#endif
294*1295d682SXin Li 
/* Window and transform one frame.

   The analysis window is the previous frame (st->analysis_mem) followed by
   the new one; `in` then becomes the stored previous frame.

   X:  receives FREQ_SIZE spectrum bins.
   Ex: receives NB_BANDS band energies of X.
   in: FRAME_SIZE new samples. */
static void frame_analysis(DenoiseState *st, kiss_fft_cpx *X, float *Ex, const float *in) {
  float win[WINDOW_SIZE];
  RNN_COPY(win, st->analysis_mem, FRAME_SIZE);
  RNN_COPY(&win[FRAME_SIZE], in, FRAME_SIZE);
  RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
  apply_window(win);
  forward_transform(X, win);
#if TRAINING
  {
    /* Simulated low-pass: clear every bin from `lowpass` upwards. */
    int k;
    for (k=lowpass;k<FREQ_SIZE;k++) {
      X[k].i = 0;
      X[k].r = 0;
    }
  }
#endif
  compute_band_energy(Ex, X);
}
309*1295d682SXin Li 
/* Analyse one frame and build the NB_FEATURES-long input vector of the RNN.

   Outputs:
     X, Ex     spectrum and band energies of the windowed input;
     P, Ep     spectrum and band energies of the input delayed by the
               estimated pitch period;
     Exp       per-band normalised correlation between X and P;
     features  [0, NB_BANDS)              cepstrum (first NB_DELTA_CEPS entries are
                                          replaced by a sum over the last 3 frames)
               [NB_BANDS, +NB_DELTA_CEPS) first temporal difference
               next NB_DELTA_CEPS         second temporal difference
               next NB_DELTA_CEPS         DCT of Exp, first two entries offset
               then                       scaled pitch period, spectral variability

   Updates st->pitch_buf, st->last_period, st->last_gain, and (unless the
   early silence return is taken) st->cepstral_mem / st->memid.

   Returns non-zero when the frame should be treated as silence.  Outside
   TRAINING builds that happens when the total band energy is below 0.04, in
   which case `features` is cleared and the cepstral history is untouched. */
static int compute_frame_features(DenoiseState *st, kiss_fft_cpx *X, kiss_fft_cpx *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in) {
  int i;
  float total_energy = 0;
  float *cur, *prev1, *prev2;
  float variability = 0;
  float log_band[NB_BANDS];
  float pitch_win[WINDOW_SIZE];
  float pitch_lowres[PITCH_BUF_SIZE>>1];
  int pitch_index;
  float pitch_gain;
  float *(pre[1]);
  float corr_ceps[NB_BANDS];
  float *corr_feat;
  float follow, logMax;

  frame_analysis(st, X, Ex, in);

  /* Slide the pitch history by one frame and append the new samples. */
  RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
  RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);

  /* Pitch search on a downsampled copy of the history. */
  pre[0] = &st->pitch_buf[0];
  pitch_downsample(pre, pitch_lowres, PITCH_BUF_SIZE, 1);
  pitch_search(pitch_lowres+(PITCH_MAX_PERIOD>>1), pitch_lowres, PITCH_FRAME_SIZE,
               PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
  pitch_index = PITCH_MAX_PERIOD-pitch_index;

  pitch_gain = remove_doubling(pitch_lowres, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
          PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
  st->last_period = pitch_index;
  st->last_gain = pitch_gain;

  /* Spectrum of the window taken pitch_index samples before the current one. */
  RNN_COPY(pitch_win, &st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index], WINDOW_SIZE);
  apply_window(pitch_win);
  forward_transform(P, pitch_win);
  compute_band_energy(Ep, P);
  compute_band_corr(Exp, X, P);
  for (i=0;i<NB_BANDS;i++) Exp[i] = Exp[i]/sqrt(.001+Ex[i]*Ep[i]);

  /* Pitch-correlation features. */
  dct(corr_ceps, Exp);
  corr_feat = &features[NB_BANDS+2*NB_DELTA_CEPS];
  for (i=0;i<NB_DELTA_CEPS;i++) corr_feat[i] = corr_ceps[i];
  corr_feat[0] -= 1.3;
  corr_feat[1] -= 0.9;
  features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);

  /* Log band energies, floored at 7 below the running maximum and at a
     level that falls by 1.5 per band from the previous bands. */
  logMax = -2;
  follow = -2;
  for (i=0;i<NB_BANDS;i++) {
    log_band[i] = log10(1e-2+Ex[i]);
    log_band[i] = MAX16(logMax-7, MAX16(follow-1.5, log_band[i]));
    logMax = MAX16(logMax, log_band[i]);
    follow = MAX16(follow-1.5, log_band[i]);
    total_energy += Ex[i];
  }
  if (!TRAINING && total_energy < 0.04) {
    /* If there's no audio, avoid messing up the state. */
    RNN_CLEAR(features, NB_FEATURES);
    return 1;
  }

  /* Cepstrum of the current frame. */
  dct(features, log_band);
  features[0] -= 12;
  features[1] -= 4;

  /* Current slot and the two preceding slots of the cepstral ring buffer. */
  cur = st->cepstral_mem[st->memid];
  prev1 = st->cepstral_mem[(st->memid + CEPS_MEM - 1) % CEPS_MEM];
  prev2 = st->cepstral_mem[(st->memid + CEPS_MEM - 2) % CEPS_MEM];
  for (i=0;i<NB_BANDS;i++) cur[i] = features[i];
  st->memid++;
  if (st->memid == CEPS_MEM) st->memid = 0;

  for (i=0;i<NB_DELTA_CEPS;i++) {
    features[i] = cur[i] + prev1[i] + prev2[i];
    features[NB_BANDS+i] = cur[i] - prev2[i];
    features[NB_BANDS+NB_DELTA_CEPS+i] =  cur[i] - 2*prev1[i] + prev2[i];
  }

  /* Spectral variability features: for each stored cepstrum, the squared
     distance to its nearest other stored cepstrum, averaged below. */
  for (i=0;i<CEPS_MEM;i++)
  {
    int j;
    float mindist = 1e15f;
    for (j=0;j<CEPS_MEM;j++)
    {
      int k;
      float dist=0;
      if (j==i) continue;
      for (k=0;k<NB_BANDS;k++)
      {
        float diff;
        diff = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
        dist += diff*diff;
      }
      mindist = MIN32(mindist, dist);
    }
    variability += mindist;
  }
  features[NB_BANDS+3*NB_DELTA_CEPS+1] = variability/CEPS_MEM-2.1;
  return TRAINING && total_energy < 0.1;
}
400*1295d682SXin Li 
/* Turn a spectrum back into FRAME_SIZE output samples by overlap-add.

   The inverse transform is windowed; its first half is added to the tail
   saved from the previous call, and its second half is saved for the next. */
static void frame_synthesis(DenoiseState *st, float *out, const kiss_fft_cpx *y) {
  float win[WINDOW_SIZE];
  const float *saved = st->synthesis_mem;
  int n;
  inverse_transform(win, y);
  apply_window(win);
  for (n=0;n<FRAME_SIZE;n++) {
    out[n] = win[n] + saved[n];
  }
  RNN_COPY(st->synthesis_mem, &win[FRAME_SIZE], FRAME_SIZE);
}
409*1295d682SXin Li 
/* Second-order IIR filter in transposed direct form II, with the leading
   numerator coefficient fixed at 1.

   y:   N output samples; may be the same buffer as x.
   mem: two state values, updated in place so calls can be chained.
   x:   N input samples.
   b:   the two remaining numerator coefficients.
   a:   the two denominator coefficients.
   The state update is evaluated in double precision. */
static void biquad(float *y, float mem[2], const float *x, const float *b, const float *a, int N) {
  int n = 0;
  while (n < N) {
    const float sample_in = x[n];
    const float sample_out = sample_in + mem[0];
    mem[0] = mem[1] + (b[0]*(double)sample_in - a[0]*(double)sample_out);
    mem[1] = (b[1]*(double)sample_in - a[1]*(double)sample_out);
    y[n] = sample_out;
    n++;
  }
}
421*1295d682SXin Li 
/* Reinforce the pitch harmonics of X using the pitch-delayed spectrum P.

   For each band a mixing weight is derived from the pitch correlation Exp
   and the band gain g, then scaled by sqrt(Ex/Ep).  The weights are
   interpolated to bins, the weighted P is added to X, and finally each band
   of X is rescaled so that its energy matches Ex again.

   X:          spectrum, modified in place (FREQ_SIZE bins).
   P:          pitch-delayed spectrum.
   Ex, Ep:     band energies of X and P.
   Exp, g:     per-band pitch correlation and gain (NB_BANDS values each). */
void pitch_filter(kiss_fft_cpx *X, const kiss_fft_cpx *P, const float *Ex, const float *Ep,
                  const float *Exp, const float *g) {
  int i;
  float r[NB_BANDS];
  float rf[FREQ_SIZE] = {0};
  float newE[NB_BANDS];
  float norm[NB_BANDS];
  float normf[FREQ_SIZE]={0};

  for (i=0;i<NB_BANDS;i++) {
    float weight;
#if 0
    if (Exp[i]>g[i]) weight = 1;
    else weight = Exp[i]*(1-g[i])/(.001 + g[i]*(1-Exp[i]));
    weight = MIN16(1, MAX16(0, weight));
#else
    if (Exp[i]>g[i]) weight = 1;
    else weight = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
    weight = sqrt(MIN16(1, MAX16(0, weight)));
#endif
    weight *= sqrt(Ex[i]/(1e-8+Ep[i]));
    r[i] = weight;
  }

  /* Mix the pitch spectrum into X. */
  interp_band_gain(rf, r);
  for (i=0;i<FREQ_SIZE;i++) {
    X[i].r += rf[i]*P[i].r;
    X[i].i += rf[i]*P[i].i;
  }

  /* Restore the original per-band energy. */
  compute_band_energy(newE, X);
  for (i=0;i<NB_BANDS;i++) {
    norm[i] = sqrt(Ex[i]/(1e-8+newE[i]));
  }
  interp_band_gain(normf, norm);
  for (i=0;i<FREQ_SIZE;i++) {
    X[i].r *= normf[i];
    X[i].i *= normf[i];
  }
}
457*1295d682SXin Li 
/* Denoise one frame.

   st:  state from rnnoise_create() / rnnoise_init().
   out: receives FRAME_SIZE output samples.
   in:  FRAME_SIZE input samples.

   Returns the voice-activity value produced by compute_rnn(), or 0 when the
   frame was classified as silence and the network was not run.  Silent
   frames are resynthesised from the unmodified spectrum. */
float rnnoise_process_frame(DenoiseState *st, float *out, const float *in) {
  static const float a_hp[2] = {-1.99599, 0.99600};
  static const float b_hp[2] = {-2, 1};
  const float alpha = .6f;
  kiss_fft_cpx X[FREQ_SIZE];
  kiss_fft_cpx P[WINDOW_SIZE];
  float filtered[FRAME_SIZE];
  float Ex[NB_BANDS], Ep[NB_BANDS];
  float Exp[NB_BANDS];
  float features[NB_FEATURES];
  float g[NB_BANDS];
  float gf[FREQ_SIZE]={1};
  float vad_prob = 0;
  int band;
  int bin;

  /* Input biquad, then feature extraction on the filtered frame. */
  biquad(filtered, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
  if (compute_frame_features(st, X, P, Ex, Ep, Exp, features, filtered)) {
    frame_synthesis(st, out, X);
    return vad_prob;
  }

  compute_rnn(&st->rnn, g, &vad_prob, features);
  pitch_filter(X, P, Ex, Ep, Exp, g);

  /* A band gain may not drop below alpha times its previous value. */
  for (band=0;band<NB_BANDS;band++) {
    g[band] = MAX16(g[band], alpha*st->lastg[band]);
    st->lastg[band] = g[band];
  }

  /* Apply the interpolated gains to the spectrum. */
  interp_band_gain(gf, g);
  for (bin=0;bin<FREQ_SIZE;bin++) {
    X[bin].r *= gf[bin];
    X[bin].i *= gf[bin];
  }

  frame_synthesis(st, out, X);
  return vad_prob;
}
495*1295d682SXin Li 
496*1295d682SXin Li #if TRAINING
497*1295d682SXin Li 
/* Pseudo-random value in [-0.5, 0.5], drawn from rand(). */
static float uni_rand() {
  const double unit = rand()/(double)RAND_MAX;
  return unit-.5;
}
501*1295d682SXin Li 
/* Draw random biquad coefficients: a[0], a[1], b[0], b[1] (in that order of
   rand() consumption), each in [-0.375, 0.375]. */
static void rand_resp(float *a, float *b) {
  float *const targets[2] = { a, b };
  int t;
  for (t=0;t<2;t++) {
    targets[t][0] = .75*uni_rand();
    targets[t][1] = .75*uni_rand();
  }
}
508*1295d682SXin Li 
main(int argc,char ** argv)509*1295d682SXin Li int main(int argc, char **argv) {
510*1295d682SXin Li   int i;
511*1295d682SXin Li   int count=0;
512*1295d682SXin Li   static const float a_hp[2] = {-1.99599, 0.99600};
513*1295d682SXin Li   static const float b_hp[2] = {-2, 1};
514*1295d682SXin Li   float a_noise[2] = {0};
515*1295d682SXin Li   float b_noise[2] = {0};
516*1295d682SXin Li   float a_sig[2] = {0};
517*1295d682SXin Li   float b_sig[2] = {0};
518*1295d682SXin Li   float mem_hp_x[2]={0};
519*1295d682SXin Li   float mem_hp_n[2]={0};
520*1295d682SXin Li   float mem_resp_x[2]={0};
521*1295d682SXin Li   float mem_resp_n[2]={0};
522*1295d682SXin Li   float x[FRAME_SIZE];
523*1295d682SXin Li   float n[FRAME_SIZE];
524*1295d682SXin Li   float xn[FRAME_SIZE];
525*1295d682SXin Li   int vad_cnt=0;
526*1295d682SXin Li   int gain_change_count=0;
527*1295d682SXin Li   float speech_gain = 1, noise_gain = 1;
528*1295d682SXin Li   FILE *f1, *f2;
529*1295d682SXin Li   int maxCount;
530*1295d682SXin Li   DenoiseState *st;
531*1295d682SXin Li   DenoiseState *noise_state;
532*1295d682SXin Li   DenoiseState *noisy;
533*1295d682SXin Li   st = rnnoise_create(NULL);
534*1295d682SXin Li   noise_state = rnnoise_create(NULL);
535*1295d682SXin Li   noisy = rnnoise_create(NULL);
536*1295d682SXin Li   if (argc!=4) {
537*1295d682SXin Li     fprintf(stderr, "usage: %s <speech> <noise> <count>\n", argv[0]);
538*1295d682SXin Li     return 1;
539*1295d682SXin Li   }
540*1295d682SXin Li   f1 = fopen(argv[1], "r");
541*1295d682SXin Li   f2 = fopen(argv[2], "r");
542*1295d682SXin Li   maxCount = atoi(argv[3]);
543*1295d682SXin Li   for(i=0;i<150;i++) {
544*1295d682SXin Li     short tmp[FRAME_SIZE];
545*1295d682SXin Li     fread(tmp, sizeof(short), FRAME_SIZE, f2);
546*1295d682SXin Li   }
547*1295d682SXin Li   while (1) {
548*1295d682SXin Li     kiss_fft_cpx X[FREQ_SIZE], Y[FREQ_SIZE], N[FREQ_SIZE], P[WINDOW_SIZE];
549*1295d682SXin Li     float Ex[NB_BANDS], Ey[NB_BANDS], En[NB_BANDS], Ep[NB_BANDS];
550*1295d682SXin Li     float Exp[NB_BANDS];
551*1295d682SXin Li     float Ln[NB_BANDS];
552*1295d682SXin Li     float features[NB_FEATURES];
553*1295d682SXin Li     float g[NB_BANDS];
554*1295d682SXin Li     short tmp[FRAME_SIZE];
555*1295d682SXin Li     float vad=0;
556*1295d682SXin Li     float E=0;
557*1295d682SXin Li     if (count==maxCount) break;
558*1295d682SXin Li     if ((count%1000)==0) fprintf(stderr, "%d\r", count);
559*1295d682SXin Li     if (++gain_change_count > 2821) {
560*1295d682SXin Li       speech_gain = pow(10., (-40+(rand()%60))/20.);
561*1295d682SXin Li       noise_gain = pow(10., (-30+(rand()%50))/20.);
562*1295d682SXin Li       if (rand()%10==0) noise_gain = 0;
563*1295d682SXin Li       noise_gain *= speech_gain;
564*1295d682SXin Li       if (rand()%10==0) speech_gain = 0;
565*1295d682SXin Li       gain_change_count = 0;
566*1295d682SXin Li       rand_resp(a_noise, b_noise);
567*1295d682SXin Li       rand_resp(a_sig, b_sig);
568*1295d682SXin Li       lowpass = FREQ_SIZE * 3000./24000. * pow(50., rand()/(double)RAND_MAX);
569*1295d682SXin Li       for (i=0;i<NB_BANDS;i++) {
570*1295d682SXin Li         if (eband5ms[i]<<FRAME_SIZE_SHIFT > lowpass) {
571*1295d682SXin Li           band_lp = i;
572*1295d682SXin Li           break;
573*1295d682SXin Li         }
574*1295d682SXin Li       }
575*1295d682SXin Li     }
576*1295d682SXin Li     if (speech_gain != 0) {
577*1295d682SXin Li       fread(tmp, sizeof(short), FRAME_SIZE, f1);
578*1295d682SXin Li       if (feof(f1)) {
579*1295d682SXin Li         rewind(f1);
580*1295d682SXin Li         fread(tmp, sizeof(short), FRAME_SIZE, f1);
581*1295d682SXin Li       }
582*1295d682SXin Li       for (i=0;i<FRAME_SIZE;i++) x[i] = speech_gain*tmp[i];
583*1295d682SXin Li       for (i=0;i<FRAME_SIZE;i++) E += tmp[i]*(float)tmp[i];
584*1295d682SXin Li     } else {
585*1295d682SXin Li       for (i=0;i<FRAME_SIZE;i++) x[i] = 0;
586*1295d682SXin Li       E = 0;
587*1295d682SXin Li     }
588*1295d682SXin Li     if (noise_gain!=0) {
589*1295d682SXin Li       fread(tmp, sizeof(short), FRAME_SIZE, f2);
590*1295d682SXin Li       if (feof(f2)) {
591*1295d682SXin Li         rewind(f2);
592*1295d682SXin Li         fread(tmp, sizeof(short), FRAME_SIZE, f2);
593*1295d682SXin Li       }
594*1295d682SXin Li       for (i=0;i<FRAME_SIZE;i++) n[i] = noise_gain*tmp[i];
595*1295d682SXin Li     } else {
596*1295d682SXin Li       for (i=0;i<FRAME_SIZE;i++) n[i] = 0;
597*1295d682SXin Li     }
598*1295d682SXin Li     biquad(x, mem_hp_x, x, b_hp, a_hp, FRAME_SIZE);
599*1295d682SXin Li     biquad(x, mem_resp_x, x, b_sig, a_sig, FRAME_SIZE);
600*1295d682SXin Li     biquad(n, mem_hp_n, n, b_hp, a_hp, FRAME_SIZE);
601*1295d682SXin Li     biquad(n, mem_resp_n, n, b_noise, a_noise, FRAME_SIZE);
602*1295d682SXin Li     for (i=0;i<FRAME_SIZE;i++) xn[i] = x[i] + n[i];
603*1295d682SXin Li     if (E > 1e9f) {
604*1295d682SXin Li       vad_cnt=0;
605*1295d682SXin Li     } else if (E > 1e8f) {
606*1295d682SXin Li       vad_cnt -= 5;
607*1295d682SXin Li     } else if (E > 1e7f) {
608*1295d682SXin Li       vad_cnt++;
609*1295d682SXin Li     } else {
610*1295d682SXin Li       vad_cnt+=2;
611*1295d682SXin Li     }
612*1295d682SXin Li     if (vad_cnt < 0) vad_cnt = 0;
613*1295d682SXin Li     if (vad_cnt > 15) vad_cnt = 15;
614*1295d682SXin Li 
615*1295d682SXin Li     if (vad_cnt >= 10) vad = 0;
616*1295d682SXin Li     else if (vad_cnt > 0) vad = 0.5f;
617*1295d682SXin Li     else vad = 1.f;
618*1295d682SXin Li 
619*1295d682SXin Li     frame_analysis(st, Y, Ey, x);
620*1295d682SXin Li     frame_analysis(noise_state, N, En, n);
621*1295d682SXin Li     for (i=0;i<NB_BANDS;i++) Ln[i] = log10(1e-2+En[i]);
622*1295d682SXin Li     int silence = compute_frame_features(noisy, X, P, Ex, Ep, Exp, features, xn);
623*1295d682SXin Li     pitch_filter(X, P, Ex, Ep, Exp, g);
624*1295d682SXin Li     //printf("%f %d\n", noisy->last_gain, noisy->last_period);
625*1295d682SXin Li     for (i=0;i<NB_BANDS;i++) {
626*1295d682SXin Li       g[i] = sqrt((Ey[i]+1e-3)/(Ex[i]+1e-3));
627*1295d682SXin Li       if (g[i] > 1) g[i] = 1;
628*1295d682SXin Li       if (silence || i > band_lp) g[i] = -1;
629*1295d682SXin Li       if (Ey[i] < 5e-2 && Ex[i] < 5e-2) g[i] = -1;
630*1295d682SXin Li       if (vad==0 && noise_gain==0) g[i] = -1;
631*1295d682SXin Li     }
632*1295d682SXin Li     count++;
633*1295d682SXin Li #if 1
634*1295d682SXin Li     fwrite(features, sizeof(float), NB_FEATURES, stdout);
635*1295d682SXin Li     fwrite(g, sizeof(float), NB_BANDS, stdout);
636*1295d682SXin Li     fwrite(Ln, sizeof(float), NB_BANDS, stdout);
637*1295d682SXin Li     fwrite(&vad, sizeof(float), 1, stdout);
638*1295d682SXin Li #endif
639*1295d682SXin Li   }
640*1295d682SXin Li   fprintf(stderr, "matrix size: %d x %d\n", count, NB_FEATURES + 2*NB_BANDS + 1);
641*1295d682SXin Li   fclose(f1);
642*1295d682SXin Li   fclose(f2);
643*1295d682SXin Li   return 0;
644*1295d682SXin Li }
645*1295d682SXin Li 
646*1295d682SXin Li #endif
647