1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import os
import argparse
import sys

import yaml

# gitpython is optional; it is only used to record the repository state in the
# setup dump below
try:
    import git
    has_git = True
except ImportError:
    has_git = False

import torch
from torch.optim.lr_scheduler import LambdaLR

from scipy.io import wavfile

import pesq

from data import LPCNetVocodingDataset
from models import model_dict
from engine.vocoder_engine import train_one_epoch, evaluate

from utils.lpcnet_features import load_lpcnet_features
from utils.misc import count_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss

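# Example invocation (paths are illustrative):
#
#   python train_vocoder.py setup.yml /path/to/output --device cuda:0
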
parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()


torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)
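
# The setup file must provide at least the keys read below; a rough sketch
# (all values illustrative, not a reference configuration):
#
#   model:
#     name: pitchpostfilter          # key into model_dict
#     args: []
#     kwargs: {}
#   data: {}                         # kwargs passed to LPCNetVocodingDataset
#   dataset: /path/to/training/data
#   validation_dataset: /path/to/validation/data   # optional
#   training:
#     batch_size: 256
#     epochs: 50
#     lr: 5.0e-4
#     lr_decay_factor: 2.5e-5
#     loss: {w_l1: ..., w_lm: ..., w_slm: ..., w_sc: ..., w_logmel: ...,
#            w_wsc: ..., w_xcorr: ..., w_sxcorr: ..., w_l2: ...}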

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: did not find model name in setup, using default pitchpostfilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = commit_hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

# reference signal for PESQ scoring; never assigned in this script, so the
# PESQ branches below are effectively disabled
ref = None

# prepare inference test if wanted
inference_test = False
if args.test_features is not None:
    test_features = load_lpcnet_features(args.test_features)
    features = test_features['features']
    periods = test_features['periods']
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)
    inference_test = True


# training parameters
batch_size      = setup['training']['batch_size']
epochs          = setup['training']['epochs']
lr              = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
data_config = setup['data']
data = LPCNetVocodingDataset(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)

# optimizer acts on trainable parameters only
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler
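# With LambdaLR the learning rate after n scheduler steps is
#   lr_n = lr / (1 + lr_decay_factor * n),
# i.e. a smooth hyperbolic decay from the initial learning rate.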
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))

# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

# total weight, used to normalize the combined loss in criterion below
w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

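# Negative normalized cross-correlation:
#   1 - <y_true, y_pred> / sqrt(||y_true||^2 * ||y_pred||^2 + 1e-9),
# averaged over the batch; the 1e-9 term guards against division by zero.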
def xcorr_loss(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

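# Mean squared error normalized by the RMS of the prediction
# (1e-6 avoids division by zero).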
def td_l2_norm(y_true, y_pred):
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

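# Mean absolute error; for pow > 0 it is normalized by the mean absolute
# value of the prediction raised to pow (criterion below uses pow=1).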
def td_l1(y_true, y_pred, pow=0):
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

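# Weighted sum of the time-domain and spectral loss terms, normalized by
# w_sum. The weights w_lm, w_sc, w_wsc, w_slm and w_sxcorr are applied
# inside stftloss (see its constructor above).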
def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup'         : setup,
    'state_dict'    : model.state_dict(),
    'loss'          : -1
}


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

print("summary:")

print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")

# model.cpu() above moved the weights to the CPU; push the model back to the
# compute device before training
model.to(device)

# ref is never assigned above, so this is currently a no-op stub
if ref is not None:
    pass

best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss']       = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
            best_loss = validation_loss

    if inference_test:
        print("running inference test...")
        out = model.process(features, periods).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))

    print()

print('Done')