# xref: /aosp_15_r20/external/rnnoise/training/rnn_train.py (revision 1295d6828459cc82c3c29cc5d7d297215250a74b)
#!/usr/bin/python

from __future__ import print_function

import h5py
import numpy as np

import keras
from keras import backend as K
from keras import losses
from keras import regularizers
from keras.constraints import Constraint
from keras.constraints import min_max_norm
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import GRU
from keras.layers import Input
from keras.layers import LSTM
from keras.layers import SimpleRNN
from keras.layers import concatenate
from keras.models import Model
from keras.models import Sequential

# Optional GPU memory capping (left disabled, as in the original):
#import tensorflow as tf
#from keras.backend.tensorflow_backend import set_session
#config = tf.ConfigProto()
#config.gpu_options.per_process_gpu_memory_fraction = 0.42
#set_session(tf.Session(config=config))

def my_crossentropy(y_true, y_pred):
    """Confidence-weighted binary cross-entropy for the VAD output.

    Targets near 0.5 (ambiguous voice activity) receive ~zero weight via
    the 2*|y_true - 0.5| factor, so the loss concentrates on frames that
    are clearly speech or clearly non-speech.
    """
    # Keras 2's K.binary_crossentropy signature is (target, output):
    # the original call passed (y_pred, y_true), inverting the roles.
    return K.mean(2 * K.abs(y_true - 0.5) * K.binary_crossentropy(y_true, y_pred), axis=-1)

def mymask(y_true):
    """Build a 0/1 mask from the targets.

    rnnoise marks "don't care" band gains with negative values; shifting by
    one and clamping at 1 yields 0 for those entries and 1 for valid ones.
    """
    shifted = y_true + 1.
    return K.minimum(shifted, 1.)

def msse(y_true, y_pred):
    """Masked squared error between square-root-compressed gains."""
    sqrt_diff = K.sqrt(y_pred) - K.sqrt(y_true)
    # mymask() zeroes the contribution of "don't care" target entries.
    return K.mean(mymask(y_true) * K.square(sqrt_diff), axis=-1)

def mycost(y_true, y_pred):
    """Composite masked loss on the denoiser gains.

    Combines a 4th-power term (penalizes large gain errors hard), a squared
    term, and a small binary cross-entropy term, all on sqrt-compressed
    gains; mymask() drops "don't care" target entries.
    """
    # Factor the repeated sqrt difference instead of recomputing it.
    sqrt_diff = K.sqrt(y_pred) - K.sqrt(y_true)
    # Keras 2's K.binary_crossentropy is (target, output) — y_true first;
    # the original call had the arguments swapped.
    return K.mean(mymask(y_true)
                  * (10 * K.square(K.square(sqrt_diff))
                     + K.square(sqrt_diff)
                     + 0.01 * K.binary_crossentropy(y_true, y_pred)),
                  axis=-1)

def my_accuracy(y_true, y_pred):
    """Confidence-weighted binary accuracy for the VAD output."""
    # K.equal yields a bool tensor; cast it before multiplying by the float
    # weight, otherwise the TensorFlow backend raises a dtype error.
    hits = K.cast(K.equal(y_true, K.round(y_pred)), 'float32')
    return K.mean(2 * K.abs(y_true - 0.5) * hits, axis=-1)

class WeightClip(Constraint):
    """Clips the weights incident to each hidden unit to lie in [-c, c].

    Used with a small c (0.499 below) so trained weights stay representable
    in rnnoise's fixed-point C inference code — TODO confirm against the
    C export tooling.
    """

    def __init__(self, c=2):
        # c: clipping magnitude applied symmetrically around zero.
        self.c = c

    def __call__(self, p):
        return K.clip(p, -self.c, self.c)

    def get_config(self):
        # Only constructor kwargs belong in the config: the original also
        # returned a 'name' key, which breaks Keras deserialization because
        # WeightClip(**config) would be called with an unexpected 'name'.
        return {'c': self.c}

reg = 0.000001                   # L2 regularization strength on GRU weights
constraint = WeightClip(0.499)   # keep weights small for fixed-point export

print('Build model...')
# 42 input features per frame; the sequence length is left unspecified so
# the same model accepts any window size.
main_input = Input(shape=(None, 42), name='main_input')
tmp = Dense(24, activation='tanh', name='input_dense',
            kernel_constraint=constraint, bias_constraint=constraint)(main_input)
# VAD branch: small GRU feeding a single sigmoid voice-activity output.
vad_gru = GRU(24, activation='tanh', recurrent_activation='sigmoid',
              return_sequences=True, name='vad_gru',
              kernel_regularizer=regularizers.l2(reg),
              recurrent_regularizer=regularizers.l2(reg),
              kernel_constraint=constraint, recurrent_constraint=constraint,
              bias_constraint=constraint)(tmp)
vad_output = Dense(1, activation='sigmoid', name='vad_output',
                   kernel_constraint=constraint, bias_constraint=constraint)(vad_gru)
# Noise-estimation branch reuses the dense features, the VAD state and the
# raw inputs.
noise_input = keras.layers.concatenate([tmp, vad_gru, main_input])
noise_gru = GRU(48, activation='relu', recurrent_activation='sigmoid',
                return_sequences=True, name='noise_gru',
                kernel_regularizer=regularizers.l2(reg),
                recurrent_regularizer=regularizers.l2(reg),
                kernel_constraint=constraint, recurrent_constraint=constraint,
                bias_constraint=constraint)(noise_input)
denoise_input = keras.layers.concatenate([vad_gru, noise_gru, main_input])
denoise_gru = GRU(96, activation='tanh', recurrent_activation='sigmoid',
                  return_sequences=True, name='denoise_gru',
                  kernel_regularizer=regularizers.l2(reg),
                  recurrent_regularizer=regularizers.l2(reg),
                  kernel_constraint=constraint, recurrent_constraint=constraint,
                  bias_constraint=constraint)(denoise_input)
# 22 sigmoid outputs: one gain per frequency band.
denoise_output = Dense(22, activation='sigmoid', name='denoise_output',
                       kernel_constraint=constraint, bias_constraint=constraint)(denoise_gru)

model = Model(inputs=main_input, outputs=[denoise_output, vad_output])

model.compile(loss=[mycost, my_crossentropy],
              metrics=[msse],
              optimizer='adam', loss_weights=[10, 0.5])

batch_size = 32

print('Loading data...')
with h5py.File('training.h5', 'r') as hf:
    all_data = hf['data'][:]
print('done.')

window_size = 2000

nb_sequences = len(all_data) // window_size
print(nb_sequences, ' sequences')

# Column layout of training.h5 rows (presumed from the slices below —
# verify against the feature-extraction tool): [0:42] input features,
# [42:64] target band gains, [64:86] noise data (loaded but unused by
# fit()), [86] VAD flag.
x_train = all_data[:nb_sequences * window_size, :42]
x_train = np.reshape(x_train, (nb_sequences, window_size, 42))

y_train = np.copy(all_data[:nb_sequences * window_size, 42:64])
y_train = np.reshape(y_train, (nb_sequences, window_size, 22))

noise_train = np.copy(all_data[:nb_sequences * window_size, 64:86])
noise_train = np.reshape(noise_train, (nb_sequences, window_size, 22))

vad_train = np.copy(all_data[:nb_sequences * window_size, 86:87])
vad_train = np.reshape(vad_train, (nb_sequences, window_size, 1))

# Drop the reference to the raw array so its memory can be reclaimed;
# the reshaped copies above are all that training needs.
all_data = None

print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape)

print('Train...')
model.fit(x_train, [y_train, vad_train],
          batch_size=batch_size,
          epochs=120,
          validation_split=0.1)
model.save("weights.hdf5")
117