import os
import argparse
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import fargan
from dataset import FARGANDataset
from stft_loss import *

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="use A(z/gamma), default: 0.9", default=0.9)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()

if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.95]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal

# model parameters
cond_size = args.cond_size


checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gamma': args.gamma}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

#model = fargan.FARGAN()
#model = nn.DataParallel(model)

if args.initial_checkpoint is not None:
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()


dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler: decays per optimizer step as lr / (1 + lr_decay * step)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay * x))

states = None

spect_loss = MultiResolutionSTFTLoss(device).to(device)

if __name__ == '__main__':
    model.to(device)

    for epoch in range(1, epochs + 1):

        running_specc = 0
        running_cont_loss = 0
        running_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
                #print("interp size", lpc.shape)
                #lpc = lpc.to(device)
                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
                #lpc = fargan.interp_lpc(lpc, 4)
                periods = periods.to(device)
                # with ~90% probability, truncate to sequence_length frames;
                # otherwise keep the full length but halve the batch
                if np.random.rand() > 0.1:
                    target = target[:, :sequence_length*160]
                    #lpc = lpc[:,:sequence_length*4,:]
                    features = features[:, :sequence_length+4, :]
                    periods = periods[:, :sequence_length+4]
                else:
                    target = target[::2, :]
                    #lpc = lpc[::2,:]
                    features = features[::2, :]
                    periods = periods[::2, :]
                target = target.to(device)
                #print(target.shape, lpc.shape)
                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)

                # teacher-forced prefix: the first nb_pre frames (160 samples each)
                # are taken from the target signal, the rest is generated
                #nb_pre = random.randrange(1, 6)
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                sig = torch.cat([pre, sig], -1)

                # continuity loss on the first generated frame plus multi-resolution spectral loss
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
                specc_loss = spect_loss(sig, target.detach())
                loss = .03*cont_loss + specc_loss

                loss.backward()
                optimizer.step()

                #model.clip_weights()

                scheduler.step()

                running_specc += specc_loss.detach().cpu().item()
                running_cont_loss += cont_loss.detach().cpu().item()

                running_loss += loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   specc=f"{running_specc/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
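
# Example invocation (illustrative sketch only; the script name "train_fargan.py" and
# the data paths are placeholders, not defined in this file):
#
#   python train_fargan.py features.f32 speech.s16 ./out \
#       --batch-size 512 --epochs 20 --lr 1e-3 --sequence-length 15
#
# Checkpoints are written to ./out/checkpoints/fargan<suffix>_<epoch>.pth and a
# previous run can be resumed from one of them via --initial-checkpoint.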