1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

# global seed for reproducible data shuffling and weight initialization
seed = 1888

import os
import argparse
import sys
import random
random.seed(seed)

import yaml

# gitpython is optional; when available, the repository state (commit hash,
# remotes, diff of a dirty tree) is recorded alongside the training setup below
try:
    import git
    has_git = True
except ImportError:
    has_git = False

import torch
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
from torch.optim.lr_scheduler import LambdaLR

import numpy as np
np.random.seed(seed)

from scipy.io import wavfile

import pesq

from data import SilkEnhancementSet
from models import model_dict
from engine.engine import train_one_epoch, evaluate

from utils.silk_features import load_inference_data
from utils.misc import count_parameters, count_nonzero_parameters

from losses.stft_loss import MRSTFTLoss, MRLogMelLoss


parser = argparse.ArgumentParser()

parser.add_argument('setup', type=str, help='setup yaml file')
parser.add_argument('output', type=str, help='output path')
parser.add_argument('--device', type=str, help='compute device', default=None)
parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None)
parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')

args = parser.parse_args()

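# Example invocation (paths illustrative, not part of the repository):
#
#   python train_model.py setup.yml output_dir --device cuda:0 --testdata testitems/item1
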
torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f, Loader=yaml.FullLoader)
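
# For reference, a setup file matching the accesses in this script has roughly
# the following shape (keys inferred from the code below; values are placeholders):
#
#   model:
#     name: <model name from models.model_dict>
#     args: [...]
#     kwargs: {...}
#   data: {...}                      # passed to SilkEnhancementSet and load_inference_data
#   dataset: <path to training data>
#   validation_dataset: <path>       # optional
#   training:
#     batch_size: ...
#     epochs: ...
#     lr: ...
#     lr_decay_factor: ...
#     loss: {w_l1: ..., w_lm: ..., w_slm: ..., w_sc: ..., w_logmel: ...,
#            w_wsc: ..., w_xcorr: ..., w_sxcorr: ..., w_l2: ...}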

checkpoint_prefix = 'checkpoint'
output_prefix = 'output'
setup_name = 'setup.yml'
output_file = 'out.txt'


# check model
if 'name' not in setup['model']:
    print('warning: setup does not specify a model name, using default PitchPostFilter')
    model_name = 'pitchpostfilter'
else:
    model_name = setup['model']['name']

# prepare output folder
if os.path.exists(args.output):
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir, search_parent_directories=True)
        setup['repo'] = dict()
        commit_hash = repo.head.object.hexsha  # avoid shadowing the hash() builtin
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")
            with open(os.path.join(args.output, 'repo.diff'), "w") as f:
                f.write(repo.git.execute(["git", "diff"]))

        setup['repo']['hash'] = commit_hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

ref = None
if args.testdata is not None:

    testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data'])

    inference_test = True
    inference_folder = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_folder, exist_ok=True)

    # a clean reference signal is optional; without it the PESQ scores are skipped
    try:
        ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16)
    except OSError:
        pass
else:
    inference_test = False

# training parameters
batch_size      = setup['training']['batch_size']
epochs          = setup['training']['epochs']
lr              = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
data_config = setup['data']
data = SilkEnhancementSet(setup['dataset'], **data_config)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8)

# optimizer over trainable parameters only
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))
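# LambdaLR multiplies the base learning rate by the lambda's return value at
# each scheduler.step(), i.e. lr(x) = lr / (1 + lr_decay_factor * x); stepping
# is presumably handled per batch inside train_one_epoch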

# loss
w_l1 = setup['training']['loss']['w_l1']
w_lm = setup['training']['loss']['w_lm']
w_slm = setup['training']['loss']['w_slm']
w_sc = setup['training']['loss']['w_sc']
w_logmel = setup['training']['loss']['w_logmel']
w_wsc = setup['training']['loss']['w_wsc']
w_xcorr = setup['training']['loss']['w_xcorr']
w_sxcorr = setup['training']['loss']['w_sxcorr']
w_l2 = setup['training']['loss']['w_l2']

w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2

# the spectral weights are baked into the multi-resolution STFT loss itself
stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device)
logmelloss = MRLogMelLoss().to(device)

def xcorr_loss(y_true, y_pred):
    """normalized cross-correlation loss: 0 for perfectly correlated signals, 1 for uncorrelated ones"""
    dims = list(range(1, len(y_true.shape)))

    loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9)

    return torch.mean(loss)

def td_l2_norm(y_true, y_pred):
    """time-domain squared error, normalized by the RMS of the prediction"""
    dims = list(range(1, len(y_true.shape)))

    loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6)

    return loss.mean()

def td_l1(y_true, y_pred, pow=0):
    """time-domain absolute error, optionally normalized by the mean absolute prediction"""
    dims = list(range(1, len(y_true.shape)))
    tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow)

    return torch.mean(tmp)

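# total loss: the time-domain and correlation terms are weighted here, while the
# spectral terms carry their weights inside stftloss; dividing by w_sum keeps the
# loss scale independent of the chosen weights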
def criterion(x, y):

    return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y)
            + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum


# model checkpoint
checkpoint = {
    'setup'         : setup,
    'state_dict'    : model.state_dict(),
    'loss'          : -1
}
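# this dict is re-saved every epoch below with refreshed weights and loss
# (plus a 'validation_loss' entry when a validation set is configured)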


if not args.no_redirect:
    print(f"re-directing output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

print("summary:")

# parameter counting does not require a CPU copy, so the model stays on its device
print(f"{count_parameters(model) / 1e6:5.3f} M parameters")
if hasattr(model, 'flop_count'):
    print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS")

if ref is not None:
    noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16)
    initial_mos = pesq.pesq(16000, ref, noisy, mode='wb')
    print(f"initial MOS (PESQ): {initial_mos}")

# best validation loss seen so far, used to keep the '_best' checkpoint current
best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss']       = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
            best_loss = validation_loss

    if inference_test:
        print("running inference test...")
        out = model.process(testsignal, features, periods, numbits).cpu().numpy()
        wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out)
        if ref is not None:
            mos = pesq.pesq(16000, ref, out, mode='wb')
            print(f"MOS (PESQ): {mos}")

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))

    print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")

print('Done')