#!/usr/bin/python3
'''Copyright (c) 2021-2022 Amazon
   Copyright (c) 2018-2019 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Train a PLC (packet loss concealment) model
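#
# Example invocation (file names here are hypothetical):
#   ./train_plc.py features.f32 loss_traces.s8 plc_model --gru-size 256 --epochs 120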

import argparse
from plc_loader import PLCLoader

parser = argparse.ArgumentParser(description='Train a PLC model')

parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
parser.add_argument('lost_file', metavar='<packet loss file>', help='packet loss traces (int8)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet_plc', help='PLC model python definition (without .py)')
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--gru-size', metavar='<units>', default=256, type=int, help='number of units in GRU (default 256)')
parser.add_argument('--cond-size', metavar='<units>', default=128, type=int, help='number of units in conditioning network (default 128)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--seq-length', metavar='<sequence length>', default=1000, type=int, help='sequence length to use (default 1000)')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--band-loss', metavar='<weight>', default=1.0, type=float, help='weight of band loss (default 1.0)')
parser.add_argument('--loss-bias', metavar='<bias>', default=0.0, type=float, help='loss bias towards low energy (default 0.0)')
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')


args = parser.parse_args()

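# import the model definition module named by --model (default: lpcnet_plc.py)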
import importlib
lpcnet = importlib.import_module(args.model)

import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
import tensorflow.keras.backend as K
import h5py

import tensorflow as tf
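# uncomment the block below to cap GPU memory usage (5120 MB on the first GPU):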
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#  try:
#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#  except RuntimeError as e:
#    print(e)

nb_epochs = args.epochs

# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size

quantize = args.quantize is not None
retrain = args.retrain is not None

if quantize:
    lr = 0.00003
    decay = 0
    input_model = args.quantize
else:
    lr = 0.001
    decay = 2.5e-5

if args.lr is not None:
    lr = args.lr

if args.decay is not None:
    decay = args.decay

if retrain:
    input_model = args.retrain

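# Loss functions. The last channel of y_true is a per-frame mask that zeroes
# out frames excluded from the loss; the remaining channels are the target
# features. In plc_loss(), the band term runs an inverse DCT over the cepstral
# channels (everything except the last two features) to measure the error in
# the spectral band domain, and the clipped L1 terms on channel 18:19 (the
# pitch parameter, as plc_pitch_loss() below suggests) bound the influence of
# large pitch errors.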
def plc_loss(alpha=1.0, bias=0.):
    def loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
        bias_mask = K.minimum(1., K.maximum(0., 4*y_true[:,:,-1:]))
        l1_loss = (K.mean(K.abs(e))
                   + 0.1*K.mean(K.maximum(0., -e[:,:,-1:]))
                   + alpha*K.mean(K.abs(e_bands) + bias*bias_mask*K.maximum(0., e_bands))
                   + K.mean(K.minimum(K.abs(e[:,:,18:19]), 1.))
                   + 8*K.mean(K.minimum(K.abs(e[:,:,18:19]), .4)))
        return l1_loss
    return loss

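# The factories below build metric-only variants of the loss so the L1,
# cepstral, band and pitch components can be monitored separately during
# training; only plc_loss() above is used as the training objective.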
def plc_l1_loss():
    def L1_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.abs(e))
        return l1_loss
    return L1_loss

def plc_ceps_loss():
    def ceps_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.abs(e[:,:,:-2]))
        return l1_loss
    return ceps_loss

def plc_band_loss():
    def L1_band_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
        l1_loss = K.mean(K.abs(e_bands))
        return l1_loss
    return L1_band_loss

def plc_pitch_loss():
    def pitch_loss(y_true,y_pred):
        mask = y_true[:,:,-1:]
        y_true = y_true[:,:,:-1]
        e = (y_pred - y_true)*mask
        l1_loss = K.mean(K.minimum(K.abs(e[:,:,18:19]),.4))
        return l1_loss
    return pitch_loss

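# Build and compile the model under a distribution strategy so training can
# run synchronously across multiple workers. Note beta_2=0.99 instead of the
# Adam default of 0.999.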
opt = Adam(lr, decay=decay, beta_2=0.99)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_ceps_loss(), plc_band_loss(), plc_pitch_loss()])
    model.summary()

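# Each float32 frame in the features file is assumed to store the
# nb_used_features network inputs and nb_burg_features Burg-derived features
# first, followed by lpc_order (16) LPC coefficients (see the slicing below).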
lpc_order = 16

feature_file = args.features
nb_features = model.nb_used_features + lpc_order + model.nb_burg_features
nb_used_features = model.nb_used_features
nb_burg_features = model.nb_burg_features
sequence_size = args.seq_length
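
# The features file is a flat stream of float32 frames. Keep only as many
# frames as fill a whole number of batches of full-length sequences, then
# reshape to (sequence, time, feature).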
features = np.memmap(feature_file, dtype='float32', mode='r')
nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
features = features[:nb_sequences*sequence_size*nb_features]

features = np.reshape(features, (nb_sequences, sequence_size, nb_features))

# drop the trailing lpc_order channels, keeping the used and Burg features
features = features[:, :, :nb_used_features+model.nb_burg_features]

lost = np.memmap(args.lost_file, dtype='int8', mode='r')

# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.gru_size, '{epoch:02d}'))

if quantize or retrain:
    # adapting from an existing model: load the starting weights for
    # quantization or continued training (covers both --quantize and --retrain,
    # so no separate load of args.retrain is needed)
    model.load_weights(input_model)

model.save_weights('{}_{}_initial.h5'.format(args.output, args.gru_size))

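# PLCLoader (from plc_loader.py) pairs the feature sequences with the packet
# loss traces to generate training batches; the exact batch layout is defined
# by that loader.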
loader = PLCLoader(features, lost, nb_burg_features, batch_size)

callbacks = [checkpoint]
if args.logdir is not None:
    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.gru_size)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    callbacks.append(tensorboard_callback)

model.fit(loader, epochs=nb_epochs, callbacks=callbacks)