"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

seed = 1888

# Seeding is interleaved with the imports on purpose: each RNG is seeded
# right after its library is imported, before any project module is loaded.
import os
import argparse
import sys
import random
random.seed(seed)

import yaml

try:
    import git
    has_git = True
except ImportError:
    # GitPython is optional; without it no repo metadata is recorded.
    has_git = False

import torch
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
from torch.optim.lr_scheduler import LambdaLR

import numpy as np
np.random.seed(seed)

from scipy.io import wavfile

import pesq

from data import SilkEnhancementSet
from models import model_dict
from engine.engine import train_one_epoch, evaluate


from utils.silk_features import load_inference_data
from utils.misc import count_parameters, count_nonzero_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

# NOTE(review): setup files are assumed to be trusted; FullLoader is less
# restrictive than yaml.safe_load, so switch if setups ever come from outside.
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: did not find model entry in setup, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # sys.exit (unlike os._exit) lets the interpreter flush and clean up
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup (best effort: any git failure only skips the metadata)
if has_git:
    working_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")
            with open(os.path.join(args.output, 'repo.diff'), "w") as f:
                f.write(repo.git.execute(["git", "diff"]))

        # assign only after every query succeeded so a failure cannot leave a
        # partially filled 'repo' entry in the dumped setup
        setup['repo'] = {'hash': commit_hash, 'urls': urls, 'dirty': is_dirty}
    except Exception as e:
        print(f"warning: could not collect repo info ({e})")
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

ref = None
if args.testdata is not None:

    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])

    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)

    # the clean reference is optional; without it PESQ evaluation is skipped
    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except OSError:
        print("warning: could not read clean.s16, skipping PESQ evaluation")
else:
    inference_test = False

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler: lr_x = lr / (1 + lr_decay_factor * x), where x counts
# scheduler.step() calls (made inside train_one_epoch; per batch or per epoch is
# not visible here)
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))

# loss
loss_setup = setup['training']['loss']
w_l1 = loss_setup['w_l1']
w_lm = loss_setup['w_lm']
w_slm = loss_setup['w_slm']
w_sc = loss_setup['w_sc']
w_logmel = loss_setup['w_logmel']
w_wsc = loss_setup['w_wsc']
w_xcorr = loss_setup['w_xcorr']
w_sxcorr = loss_setup['w_sxcorr']
w_l2 = loss_setup['w_l2']

# normalizer for criterion(); includes the weights applied inside MRSTFTLoss
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2
if w_sum == 0:
    # dividing by zero would silently turn every loss into inf/nan
    raise ValueError("sum of loss weights in setup['training']['loss'] must be non-zero")

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)


def xcorr_loss(y_true, y_pred):
    """Return 1 minus the normalized cross-correlation of y_true and y_pred.

    The correlation is computed separately for each index of the leading
    dimension (presumably the batch) over all remaining dimensions, and the
    per-item losses are averaged.
    """
    dims = list(range(1, len(y_true.shape)))

    # 1e-9 keeps the square root's argument positive for all-zero inputs
    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)


def td_l2_norm(y_true, y_pred):
    """Return the time-domain MSE normalized by the RMS of y_pred.

    MSE and RMS are taken per index of the leading dimension over all other
    dimensions, then averaged. Note the divisor is the RMS (not the power)
    of the prediction, plus 1e-6 to avoid division by zero.
    """
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()


def td_l1(y_true, y_pred, pow=0):
    """Return the time-domain L1 error, optionally normalized by y_pred's level.

    The mean absolute error is divided by (mean |y_pred| + 1e-9) ** pow, so
    pow=0 gives the plain L1 loss and pow=1 a level-normalized one. Reduction
    is over all but the leading dimension, followed by a mean.
    """
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)


def criterion(x, y):
    """Return the weighted training loss, normalized by the sum of all weights.

    x is passed as the reference (y_true) and y as the estimate (y_pred) to the
    time-domain terms; TODO confirm this matches how engine calls criterion.
    The STFT-based terms carry their weights inside ``stftloss``.
    """
    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    # line-buffered so the log file shows progress while training is running
    sys.stdout = open(os.path.join(args.output, output_file), "w", buffering=1)

print("summary:")

# Module.cpu() moves the model in place; it is returned to `device` below
print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")

# undo the in-place move to CPU done for the summary before training starts
model.to(device)

if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

best_loss = float('inf')

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss'] = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
            best_loss = validation_loss

    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))

    print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")

print('Done')