#!/usr/bin/python3
'''Copyright (c) 2021-2022 Amazon
   Copyright (c) 2018-2019 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Train a PLC model: command-line interface definition.

import argparse
from plc_loader import PLCLoader

parser = argparse.ArgumentParser(description='Train a PLC model')

# Required positional arguments: training inputs and the output name.
parser.add_argument(
    'features',
    metavar='<features file>',
    help='binary features file (float32)',
)
parser.add_argument(
    'lost_file',
    metavar='<packet loss file>',
    help='packet loss traces (int8)',
)
parser.add_argument(
    'output',
    metavar='<output>',
    help='trained model file (.h5)',
)

# Python module that provides the model definition.
parser.add_argument(
    '--model',
    metavar='<model>',
    default='lpcnet_plc',
    help='PLC model python definition (without .py)',
)

# Quantizing and retraining both start from existing weights; only one of
# the two may be given on a command line.
group1 = parser.add_mutually_exclusive_group()
group1.add_argument(
    '--quantize',
    metavar='<input weights>',
    help='quantize model',
)
group1.add_argument(
    '--retrain',
    metavar='<input weights>',
    help='continue training model',
)

# Network sizes.
parser.add_argument(
    '--gru-size',
    metavar='<units>',
    default=256,
    type=int,
    help='number of units in GRU (default 256)',
)
parser.add_argument(
    '--cond-size',
    metavar='<units>',
    default=128,
    type=int,
    help='number of units in conditioning network (default 128)',
)

# Training schedule.
parser.add_argument(
    '--epochs',
    metavar='<epochs>',
    default=120,
    type=int,
    help='number of epochs to train for (default 120)',
)
parser.add_argument(
    '--batch-size',
    metavar='<batch size>',
    default=128,
    type=int,
    help='batch size to use (default 128)',
)
parser.add_argument(
    '--seq-length',
    metavar='<sequence length>',
    default=1000,
    type=int,
    help='sequence length to use (default 1000)',
)

# Optimizer settings; these have no argparse default, so they are None
# unless given explicitly.
parser.add_argument(
    '--lr',
    metavar='<learning rate>',
    type=float,
    help='learning rate',
)
parser.add_argument(
    '--decay',
    metavar='<decay>',
    type=float,
    help='learning rate decay',
)

# Loss shaping.
parser.add_argument(
    '--band-loss',
    metavar='<weight>',
    default=1.0,
    type=float,
    help='weight of band loss (default 1.0)',
)
parser.add_argument(
    '--loss-bias',
    metavar='<bias>',
    default=0.0,
    type=float,
    help='loss bias towards low energy (default 0.0)',
)

# Optional TensorBoard logging.
parser.add_argument(
    '--logdir',
    metavar='<log dir>',
    help='directory for tensorboard log files',
)


args = parser.parse_args()

# The module providing the model definition is selected with --model.
import importlib
lpcnet = importlib.import_module(args.model)

# NOTE(review): these imports come after parse_args(); presumably so that
# --help and argument errors return without loading TensorFlow -- confirm
# before moving them to the top of the file.
import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
import tensorflow.keras.backend as K
import h5py

import tensorflow as tf
# Optional snippet (disabled) for capping the memory used on the first GPU.
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#  try:
#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#  except RuntimeError as e:
#    print(e)

nb_epochs = args.epochs

# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size

quantize = args.quantize is not None
retrain = args.retrain is not None

# Quantization starts from existing weights and uses a smaller learning rate
# with no decay; a regular run uses the larger defaults.
if quantize:
    lr = 0.00003
    decay = 0
    input_model = args.quantize
else:
    lr = 0.001
    decay = 2.5e-5

# Explicit command-line values override the defaults chosen above.
if args.lr is not None:
    lr = args.lr

if args.decay is not None:
    decay = args.decay

if retrain:
    input_model = args.retrain


def _masked_error(y_true, y_pred):
    """Split the mask channel off y_true and return (targets, masked error).

    The last channel of y_true is used as a multiplicative mask on the
    prediction error; the remaining channels are the regression targets.
    """
    mask = y_true[:, :, -1:]
    targets = y_true[:, :, :-1]
    return targets, (y_pred - targets) * mask


def plc_loss(alpha=1.0, bias=0.):
    """Build the training loss.

    alpha weights the term computed on the IDCT of the error, and bias
    weights an extra penalty on positive values of that transformed error.
    """
    def loss(y_true, y_pred):
        targets, e = _masked_error(y_true, y_pred)
        # IDCT of every error channel except the last two (presumably maps
        # cepstral error to band energies -- TODO confirm feature layout).
        e_bands = tf.signal.idct(e[:, :, :-2], norm='ortho')
        # Last target channel scaled by 4 and clipped to [0, 1]; gates the
        # bias penalty below.
        bias_mask = K.minimum(1., K.maximum(0., 4*targets[:, :, -1:]))
        abs_term = K.mean(K.abs(e))
        # Extra penalty when the last channel is under-predicted.
        under_term = 0.1*K.mean(K.maximum(0., -e[:, :, -1:]))
        band_term = alpha*K.mean(K.abs(e_bands) + bias*bias_mask*K.maximum(0., e_bands))
        # Channel 18 receives two additional clipped L1 terms (the same
        # channel is what pitch_loss reports).
        clipped_wide = K.mean(K.minimum(K.abs(e[:, :, 18:19]), 1.))
        clipped_narrow = 8*K.mean(K.minimum(K.abs(e[:, :, 18:19]), .4))
        l1_loss = abs_term + under_term + band_term + clipped_wide + clipped_narrow
        return l1_loss
    return loss


def plc_l1_loss():
    """Build a metric: mean absolute masked error over all channels."""
    def L1_loss(y_true, y_pred):
        _, e = _masked_error(y_true, y_pred)
        return K.mean(K.abs(e))
    return L1_loss


def plc_ceps_loss():
    """Build a metric: mean absolute masked error, excluding the last two channels."""
    def ceps_loss(y_true, y_pred):
        _, e = _masked_error(y_true, y_pred)
        return K.mean(K.abs(e[:, :, :-2]))
    return ceps_loss


def plc_band_loss():
    """Build a metric: mean absolute IDCT of the masked error, excluding the last two channels."""
    def L1_band_loss(y_true, y_pred):
        _, e = _masked_error(y_true, y_pred)
        e_bands = tf.signal.idct(e[:, :, :-2], norm='ortho')
        return K.mean(K.abs(e_bands))
    return L1_band_loss


def plc_pitch_loss():
    """Build a metric: mean absolute masked error on channel 18, clipped at 0.4."""
    def pitch_loss(y_true, y_pred):
        _, e = _masked_error(y_true, y_pred)
        return K.mean(K.minimum(K.abs(e[:, :, 18:19]), .4))
    return pitch_loss


# NOTE(review): recent Keras releases reportedly reject the `decay` keyword on
# the non-legacy Adam optimizer -- confirm against the TensorFlow version in use.
opt = Adam(lr, decay=decay, beta_2=0.99)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = lpcnet.new_lpcnet_plc_model(rnn_units=args.gru_size, batch_size=batch_size, training=True, quantize=quantize, cond_size=args.cond_size)
    model.compile(optimizer=opt, loss=plc_loss(alpha=args.band_loss, bias=args.loss_bias), metrics=[plc_l1_loss(), plc_ceps_loss(), plc_band_loss(), plc_pitch_loss()])
    model.summary()

lpc_order = 16

feature_file = args.features
nb_used_features = model.nb_used_features
nb_burg_features = model.nb_burg_features
# Each frame in the file holds the used features, lpc_order further values
# and the Burg features.
nb_features = nb_used_features + lpc_order + nb_burg_features
sequence_size = args.seq_length

features = np.memmap(feature_file, dtype='float32', mode='r')
# Keep only whole sequences, rounded down to a multiple of the batch size.
nb_sequences = len(features)//(nb_features*sequence_size)//batch_size*batch_size
if nb_sequences == 0:
    # Fail here with a clear message instead of training on an empty array.
    sys.exit('{}: not enough data for a single batch of {} sequences of {} frames'.format(feature_file, batch_size, sequence_size))
features = features[:nb_sequences*sequence_size*nb_features]

features = np.reshape(features, (nb_sequences, sequence_size, nb_features))

# Drop the trailing lpc_order values of every frame.
features = features[:, :, :nb_used_features+nb_burg_features]

lost = np.memmap(args.lost_file, dtype='int8', mode='r')

# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.gru_size, '{epoch:02d}'))

# Start from existing weights when quantizing or continuing training.
# input_model is set earlier in the script to args.quantize or args.retrain
# whenever one of those options is given, so a single load covers both cases
# (previously the --retrain weights were loaded twice).
if quantize or retrain:
    model.load_weights(input_model)

# Snapshot of the weights before any training step has run.
model.save_weights('{}_{}_initial.h5'.format(args.output, args.gru_size))

loader = PLCLoader(features, lost, nb_burg_features, batch_size)

callbacks = [checkpoint]
if args.logdir is not None:
    # One TensorBoard log directory per output name and GRU size.
    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.gru_size)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    callbacks.append(tensorboard_callback)

model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)