import os
import argparse
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import plc
from plc_dataset import PLCDataset

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('loss', type=str, help='path to signal file in .s8 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 128", default=128)
model_group.add_argument('--gru-size', type=int, help="GRU size, default: 128", default=128)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()

if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.95]
adam_eps = 1e-8
features_file = args.features
loss_file = args.loss

# model parameters
cond_size = args.cond_size


# store hyperparameters in the checkpoint so training can be reproduced
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gru_size': args.gru_size}
print(checkpoint['model_kwargs'])
model = plc.PLC(*checkpoint['model_args'], **checkpoint['model_kwargs'])

# optionally resume from an existing checkpoint
if args.initial_checkpoint is not None:
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict'] = model.state_dict()


dataset = PLCDataset(features_file, loss_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler: inverse-linear decay per optimizer step
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay * x))

states = None

# composite PLC loss; returns the total loss plus its individual terms
plc_loss = plc.plc_loss(18, device=device)

if __name__ == '__main__':
    model.to(device)

    for epoch in range(1, epochs + 1):

        running_loss = 0
        running_l1_loss = 0
        running_ceps_loss = 0
        running_band_loss = 0
        running_pitch_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, lost, target) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                lost = lost.to(device)
                target = target.to(device)

                out, states = model(features, lost)

                loss, l1_loss, ceps_loss, band_loss, pitch_loss = plc_loss(target, out)

                loss.backward()
                optimizer.step()

                #model.clip_weights()

                scheduler.step()

                # accumulate running averages for the progress bar
                running_loss += loss.detach().cpu().item()
                running_l1_loss += l1_loss.detach().cpu().item()
                running_ceps_loss += ceps_loss.detach().cpu().item()
                running_band_loss += band_loss.detach().cpu().item()
                running_pitch_loss += pitch_loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   l1_loss=f"{running_l1_loss/(i+1):8.5f}",
                                   ceps_loss=f"{running_ceps_loss/(i+1):8.5f}",
                                   band_loss=f"{running_band_loss/(i+1):8.5f}",
                                   pitch_loss=f"{running_pitch_loss/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'plc{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
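
# Example invocation (a sketch only: the script name and data file names below are
# illustrative assumptions, not part of this repository; the flags match the
# argparse options defined above):
#
#   python train_plc.py features.f32 loss.s8 output_dir \
#       --batch-size 512 --lr 1e-3 --epochs 20 --sequence-length 15 --suffix _run1
#
# Checkpoints are then written per epoch to output_dir/checkpoints/plc_run1_<epoch>.pth.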