1"""
2/* Copyright (c) 2022 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import os
import argparse

import torch
import tqdm

from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate


parser = argparse.ArgumentParser()

parser.add_argument('features', type=str, help='path to feature file in .f32 format')
parser.add_argument('output', type=str, help='path to output folder')

parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: ''", default="")


model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--latent-dim', type=int, help="number of symbols produced by encoder, default: 80", default=80)
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
model_group.add_argument('--state-dim', type=int, help="dimensionality of transferred state, default: 24", default=24)
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16)
model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)
model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)

training_group = parser.add_argument_group(title="training parameters")
training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256)
training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')

args = parser.parse_args()
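
# Example invocation (paths are illustrative, not part of the script):
#   python train_rdovae.py features.f32 output_dir --cuda-visible-devices 0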

# set visible devices
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices

# checkpoints
checkpoint_dir = os.path.join(args.output, 'checkpoints')
checkpoint = dict()
os.makedirs(checkpoint_dir, exist_ok=True)

# training parameters
batch_size = args.batch_size
lr = args.lr
epochs = args.epochs
sequence_length = args.sequence_length
lr_decay_factor = args.lr_decay_factor
split_mode = args.split_mode
# not exposed
adam_betas = [0.8, 0.95]
adam_eps = 1e-8

checkpoint['batch_size'] = batch_size
checkpoint['lr'] = lr
checkpoint['lr_decay_factor'] = lr_decay_factor
checkpoint['split_mode'] = split_mode
checkpoint['epochs'] = epochs
checkpoint['sequence_length'] = sequence_length
checkpoint['adam_betas'] = adam_betas

# logging
log_interval = 10

# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# model parameters
cond_size  = args.cond_size
cond_size2 = args.cond_size2
latent_dim = args.latent_dim
quant_levels = args.quant_levels
lambda_min = args.lambda_min
lambda_max = args.lambda_max
state_dim = args.state_dim
# not exposed
num_features = 20


# training data
feature_file = args.features

# model
checkpoint['model_args']    = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
checkpoint['model_kwargs']  = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])

if args.initial_checkpoint is not None:
    # resuming: the loaded checkpoint replaces the dict assembled above
    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)

checkpoint['state_dict']    = model.state_dict()

if args.train_decoder_only:
    if args.initial_checkpoint is None:
        print("warning: training decoder only without providing initial checkpoint")

    for p in model.core_encoder.module.parameters():
        p.requires_grad = False

    for p in model.statistical_model.parameters():
        p.requires_grad = False

# dataloader
checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36)
checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels}
dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)


# optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)


# learning rate scheduler
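# inverse-linear decay: lr(step) = lr / (1 + lr_decay_factor * step); the scheduler
# is stepped once per batch below, so "step" is the global batch count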
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))

if __name__ == '__main__':

    # push model to device
    model.to(device)

    # training loop

    for epoch in range(1, epochs + 1):

        print(f"training epoch {epoch}...")

        # running stats
        running_rate_loss          = 0
        running_soft_dist_loss     = 0
        running_hard_dist_loss     = 0
        running_hard_rate_loss     = 0
        running_soft_rate_loss     = 0
        running_total_loss         = 0
        running_rate_metric        = 0
        running_states_rate_metric = 0
        previous_total_loss        = 0
        running_first_frame_loss   = 0

        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (features, rate_lambda, q_ids) in enumerate(tepoch):

                # zero out gradients
                optimizer.zero_grad()

                # push inputs to device
                features    = features.to(device)
                q_ids       = q_ids.to(device)
                rate_lambda = rate_lambda.to(device)

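                # upsample lambda from the latent frame rate to the feature frame rate
                # (presumably two feature frames per latent step, matching enc_stride)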
                rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)

                # run model
                model_output = model(features, q_ids)

                # collect outputs
                z                   = model_output['z']
                states              = model_output['states']
                outputs_hard_quant  = model_output['outputs_hard_quant']
                outputs_soft_quant  = model_output['outputs_soft_quant']
                statistical_model   = model_output['statistical_model']

                # rate loss
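                # (soft and hard quantization both contribute; the hard estimate is
                # weighted by 0.1, state rates by 0.02, and the mean hard rates are
                # also tracked as metrics)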
                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:, :, :latent_dim], statistical_model['theta_hard'][:, :, :latent_dim], reduce=False)
                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:, :, :latent_dim], reduce=False)
                states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:, :, latent_dim:], statistical_model['theta_hard'][:, :, latent_dim:], reduce=False)
                states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:, :, latent_dim:], reduce=False)
                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + 0.02 * states_soft_rate))
                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + 0.02 * states_hard_rate))
                rate_loss = soft_rate_loss + 0.1 * hard_rate_loss
                hard_rate_metric = torch.mean(hard_rate)
                states_rate_metric = torch.mean(states_hard_rate)

                ## distortion losses

                # hard quantized decoder input
                distortion_loss_hard_quant = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_hard_quant:
                    distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant)

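                # extra distortion term on the last 4 output frames of each decoded chunk;
                # it only enters the total loss when --enable-first-frame-loss is set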
                first_frame_loss = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_hard_quant:
                    first_frame_loss += distortion_loss(features[..., stop - 4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant)

                # soft quantized decoder input
                distortion_loss_soft_quant = torch.zeros_like(rate_loss)
                for dec_features, start, stop in outputs_soft_quant:
                    distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant)

                # total loss: rate plus the average of hard- and soft-quantized distortion
                total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2

                if args.enable_first_frame_loss:
                    total_loss = 0.97 * total_loss + 0.03 * first_frame_loss


                total_loss.backward()

                optimizer.step()

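                # constrain the weights after every update; clip_weights and sparsify are
                # RDOVAE methods, presumably keeping the model quantization-friendly and
                # sparse for the C inference implementation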
                model.clip_weights()
                model.sparsify()

                scheduler.step()

                # collect running stats
                running_hard_dist_loss     += float(distortion_loss_hard_quant.detach().cpu())
                running_soft_dist_loss     += float(distortion_loss_soft_quant.detach().cpu())
                running_rate_loss          += float(rate_loss.detach().cpu())
                running_rate_metric        += float(hard_rate_metric.detach().cpu())
                running_states_rate_metric += float(states_rate_metric.detach().cpu())
                running_total_loss         += float(total_loss.detach().cpu())
                running_first_frame_loss   += float(first_frame_loss.detach().cpu())
                running_soft_rate_loss     += float(soft_rate_loss.detach().cpu())
                running_hard_rate_loss     += float(hard_rate_loss.detach().cpu())

                if (i + 1) % log_interval == 0:
                    current_loss = (running_total_loss - previous_total_loss) / log_interval
                    tepoch.set_postfix(
                        current_loss=current_loss,
                        total_loss=running_total_loss / (i + 1),
                        dist_hq=running_hard_dist_loss / (i + 1),
                        dist_sq=running_soft_dist_loss / (i + 1),
                        rate_loss=running_rate_loss / (i + 1),
                        rate=running_rate_metric / (i + 1),
                        states_rate=running_states_rate_metric / (i + 1),
                        ffloss=running_first_frame_loss / (i + 1),
                        rateloss_hard=running_hard_rate_loss / (i + 1),
                        rateloss_soft=running_soft_rate_loss / (i + 1)
                    )
                    previous_total_loss = running_total_loss

        # save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_total_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)
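
# A saved checkpoint carries everything needed to rebuild the model, e.g. (illustrative paths):
#   ckpt  = torch.load('output_dir/checkpoints/checkpoint_epoch_100.pth', map_location='cpu')
#   model = RDOVAE(*ckpt['model_args'], **ckpt['model_kwargs'])
#   model.load_state_dict(ckpt['state_dict'])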