# xref: /aosp_15_r20/external/libopus/dnn/torch/fargan/train_fargan.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
import os
import argparse
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import fargan
from dataset import FARGANDataset
from stft_loss import *

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()
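
# Example invocation (paths are illustrative, not part of the original script):
#   python train_fargan.py features.f32 signal.s16 output_dir --batch-size 512 --epochs 20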

if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.95]
adam_eps = 1e-8
features_file = args.features
signal_file = args.signal

# model parameters
cond_size = args.cond_size


checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

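# The model's constructor arguments are stored in the checkpoint dict (alongside the
# training hyperparameters above) so every saved .pth records how to rebuild the model.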
checkpoint['model_args']    = ()
checkpoint['model_kwargs']  = {'cond_size': cond_size, 'gamma': args.gamma}
print(checkpoint['model_kwargs'])
model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])

#model = fargan.FARGAN()
#model = nn.DataParallel(model)

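# Resuming from --initial-checkpoint replaces the checkpoint dict built above with the
# loaded one, so a resumed run keeps the saved hyperparameters; strict=False tolerates
# missing or unexpected keys in the stored state_dict.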
if args.initial_checkpoint is not None:
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict']    = model.state_dict()


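# FARGANDataset pairs the .f32 feature stream with the .s16 target waveform; judging from
# the unpacking in the training loop, each item provides (features, periods, target, lpc).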
dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler
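# Inverse-linear decay, stepped once per batch (scheduler.step() is called in the inner
# loop): lr(t) = lr / (1 + lr_decay * t), where t counts optimizer steps.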
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay * x))

states = None

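# Multi-resolution STFT loss from stft_loss.py (wildcard import above); presumably it
# compares magnitude spectra at several FFT resolutions, as the name suggests.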
spect_loss = MultiResolutionSTFTLoss(device).to(device)

if __name__ == '__main__':
    model.to(device)

    for epoch in range(1, epochs + 1):

        running_specc = 0
        running_cont_loss = 0
        running_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, periods, target, lpc) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
                #print("interp size", lpc.shape)
                #lpc = lpc.to(device)
                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
                #lpc = fargan.interp_lpc(lpc, 4)
                periods = periods.to(device)
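                # Sequence-length jitter: about 90% of batches are trimmed to
                # sequence_length frames (160 samples per frame, plus 4 extra conditioning
                # frames of features); the rest keep the full-length sequences but drop
                # every other batch element, presumably to keep memory use comparable.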
                if np.random.rand() > 0.1:
                    target = target[:, :sequence_length*160]
                    #lpc = lpc[:,:sequence_length*4,:]
                    features = features[:,:sequence_length+4,:]
                    periods = periods[:,:sequence_length+4]
                else:
                    target = target[::2, :]
                    #lpc=lpc[::2,:]
                    features = features[::2,:]
                    periods = periods[::2,:]
                target = target.to(device)
                #print(target.shape, lpc.shape)
                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)

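                # Teacher forcing: the first nb_pre frames of the target are handed to the
                # model as context (pre); the model generates the remaining frames and the
                # context is prepended again so sig and target stay sample-aligned.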
                #nb_pre = random.randrange(1, 6)
                nb_pre = 2
                pre = target[:, :nb_pre*160]
                sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                sig = torch.cat([pre, sig], -1)

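                # Loss: a lightly weighted time-domain continuity term on the first generated
                # frame (the 160 samples right after the teacher-forced context), plus the
                # spectral loss over the whole sequence.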
                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
                specc_loss = spect_loss(sig, target.detach())
                loss = .03*cont_loss + specc_loss

                loss.backward()
                optimizer.step()

                #model.clip_weights()

                scheduler.step()

                running_specc += specc_loss.detach().cpu().item()
                running_cont_loss += cont_loss.detach().cpu().item()

                running_loss += loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                   specc=f"{running_specc/(i+1):8.5f}",
                                   )

        # save checkpoint
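        # One checkpoint per epoch, named fargan{suffix}_{epoch}.pth; 'loss' records the
        # epoch's average total loss.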
        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)