"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import argparse
import sys

import yaml

try:
    import git
    has_git = True
except ImportError:
    has_git = False

import torch
from torch.optim.lr_scheduler import LambdaLR

from scipy.io import wavfile

import pesq

from data import LPCNetVocodingDataset
from models import model_dict
from engine.vocoder_engine import train_one_epoch, evaluate


from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables redirection of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)
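# The keys accessed further down suggest roughly the following setup-file
# layout. This is an illustrative sketch only: the values shown are
# placeholders, not defaults shipped with the code.
#
#   model:
#     name: <model name, must be a key of model_dict>
#     args: []
#     kwargs: {}
#   training:
#     batch_size: <int>
#     epochs: <int>
#     lr: <float>
#     lr_decay_factor: <float>
#     loss: {w_l1: <float>, w_lm: <float>, w_slm: <float>, w_sc: <float>,
#            w_logmel: <float>, w_wsc: <float>, w_xcorr: <float>,
#            w_sxcorr: <float>, w_l2: <float>}
#   data: <kwargs dict passed to LPCNetVocodingDataset>
#   dataset: <path to training data>
#   validation_dataset: <path to validation data, optional>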
checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print("warning: did not find model name in setup, using default PitchPostFilter")
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = commit_hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

ref = None
# prepare inference test if wanted
inference_test = False
if args.test_features is not None:
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True


# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)

# optimizer over trainable parameters only
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))
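# Note: with the lr_lambda above, the effective learning rate after n scheduler
# steps is lr / (1 + lr_decay_factor * n), i.e. a smooth hyperbolic decay
# (assuming the training engine calls scheduler.step() once per step it counts).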
# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc,
                      smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

def xcorr_loss(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

print("summary:")

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")

if ref is not None:
    pass

best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)


    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss'] = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
            best_loss = validation_loss

    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")


    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))


    print()

print('Done')
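# Example invocation (illustrative only: the script and path names are
# placeholders, the flags match the argparse definitions above):
#
#   python train_vocoder.py setup.yml output_dir --device cuda:0 --test-features test_features.f32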