# /aosp_15_r20/external/libopus/dnn/torch/plc/train_plc.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
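# Training script for the neural packet loss concealment (PLC) model defined in plc.py.
# It reads a feature file (.f32) and a packet loss trace (.s8), trains plc.PLC, and
# writes one checkpoint per epoch to <output>/checkpoints.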
import os
import argparse
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import tqdm

import plc
from plc_dataset import PLCDataset

parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('loss', type=str, help='path to packet loss file in .s8 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--suffix', type=str, help="model name suffix", default="")
parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of CUDA visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 128", default=128)
model_group.add_argument('--gru-size', type=int, help="GRU size, default: 128", default=128)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)

args = parser.parse_args()

if args.cuda_visible_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)


# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay = args.lr_decay

adam_betas = [0.8, 0.95]
adam_eps = 1e-8
features_file = args.features
loss_file = args.loss

# model parameters
cond_size  = args.cond_size

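# record the training hyperparameters in the checkpoint dict so that saved
# .pth files carry their own configuration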
checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay'] = lr_decay
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

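# the model is always built from the (args, kwargs) stored in the checkpoint,
# so a saved checkpoint is sufficient to re-instantiate it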
checkpoint['model_args']    = ()
checkpoint['model_kwargs']  = {'cond_size': cond_size, 'gru_size': args.gru_size}
print(checkpoint['model_kwargs'])
model = plc.PLC(*checkpoint['model_args'], **checkpoint['model_kwargs'])

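# optionally warm-start from an existing checkpoint; this replaces the checkpoint
# dict built above, and strict=False tolerates missing or unexpected parameters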
if args.initial_checkpoint is not None:
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict']    = model.state_dict()


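# the dataset yields (features, lost, target) tuples; see plc_dataset.py for the exact layout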
dataset = PLCDataset(features_file, loss_file, sequence_length=sequence_length)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


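# AdamW optimizer; note the betas (0.8, 0.95) are lower than the PyTorch defaults (0.9, 0.999)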
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler: lr_t = lr / (1 + lr_decay * t), stepped once per batch
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))

states = None

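# composite PLC loss from plc.py: it returns the total loss together with its
# L1, cepstral, band and pitch terms; the argument 18 is interpreted by plc.plc_loss()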
plc_loss = plc.plc_loss(18, device=device)
if __name__ == '__main__':
    model.to(device)

    for epoch in range(1, epochs + 1):

        running_loss = 0
        running_l1_loss = 0
        running_ceps_loss = 0
        running_band_loss = 0
        running_pitch_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, lost, target) in enumerate(tepoch):
                optimizer.zero_grad()
                features = features.to(device)
                lost = lost.to(device)
                target = target.to(device)

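                # run the PLC model on the feature/loss-flag sequences; the returned
                # recurrent state is not carried over to the next batch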
                out, states = model(features, lost)

                loss, l1_loss, ceps_loss, band_loss, pitch_loss = plc_loss(target, out)

                loss.backward()
                optimizer.step()

                #model.clip_weights()

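                # decay the learning rate after every optimizer step (see LambdaLR above)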
                scheduler.step()

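                # accumulate running averages of the individual loss terms for the progress bar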
                running_loss += loss.detach().cpu().item()
                running_l1_loss += l1_loss.detach().cpu().item()
                running_ceps_loss += ceps_loss.detach().cpu().item()
                running_band_loss += band_loss.detach().cpu().item()
                running_pitch_loss += pitch_loss.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                   l1_loss=f"{running_l1_loss/(i+1):8.5f}",
                                   ceps_loss=f"{running_ceps_loss/(i+1):8.5f}",
                                   band_loss=f"{running_band_loss/(i+1):8.5f}",
                                   pitch_loss=f"{running_pitch_loss/(i+1):8.5f}",
                                   )

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'plc{args.suffix}_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)