1""" 2/* Copyright (c) 2022 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import os 31import argparse 32 33import torch 34import tqdm 35 36from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate 37 38 39parser = argparse.ArgumentParser() 40 41parser.add_argument('features', type=str, help='path to feature file in .f32 format') 42parser.add_argument('output', type=str, help='path to output folder') 43 44parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: ''", default="") 45 46 47model_group = parser.add_argument_group(title="model parameters") 48model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80) 49model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256) 50model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256) 51model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24) 52model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16) 53model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4) 54model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104) 55model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82) 56model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0) 57 58training_group = parser.add_argument_group(title="training parameters") 59training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32) 60training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4) 61training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100) 62training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256) 63training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5) 64training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split') 65training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames') 66training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None) 67training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only') 68 69args = parser.parse_args() 70 71# set visible devices 72os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices 73 74# checkpoints 75checkpoint_dir = os.path.join(args.output, 'checkpoints') 76checkpoint = dict() 77os.makedirs(checkpoint_dir, exist_ok=True) 78 79# training parameters 80batch_size = args.batch_size 81lr = args.lr 82epochs = args.epochs 83sequence_length = args.sequence_length 84lr_decay_factor = args.lr_decay_factor 85split_mode = args.split_mode 86# not exposed 87adam_betas = [0.8, 0.95] 88adam_eps = 1e-8 89 90checkpoint['batch_size'] = batch_size 91checkpoint['lr'] = lr 92checkpoint['lr_decay_factor'] = lr_decay_factor 93checkpoint['split_mode'] = split_mode 94checkpoint['epochs'] = epochs 95checkpoint['sequence_length'] = sequence_length 96checkpoint['adam_betas'] = adam_betas 97 98# logging 99log_interval = 10 100 101# device 102device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 103 104# model parameters 105cond_size = args.cond_size 106cond_size2 = args.cond_size2 107latent_dim = args.latent_dim 108quant_levels = args.quant_levels 109lambda_min = args.lambda_min 110lambda_max = args.lambda_max 111state_dim = args.state_dim 112# not expsed 113num_features = 20 114 115 116# training data 117feature_file = args.features 118 119# model 120checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2) 121checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate} 122model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) 123 124if type(args.initial_checkpoint) != type(None): 125 checkpoint = torch.load(args.initial_checkpoint, map_location='cpu') 126 model.load_state_dict(checkpoint['state_dict'], strict=False) 127 128checkpoint['state_dict'] = model.state_dict() 129 130if args.train_decoder_only: 131 if args.initial_checkpoint is None: 132 print("warning: training decoder only without providing initial checkpoint") 133 134 for p in model.core_encoder.module.parameters(): 135 p.requires_grad = False 136 137 for p in model.statistical_model.parameters(): 138 p.requires_grad = False 139 140# dataloader 141checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36) 142checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels} 143dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs']) 144dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4) 145 146 147 148# optimizer 149params = [p for p in model.parameters() if p.requires_grad] 150optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps) 151 152 153# learning rate scheduler 154scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x)) 155 156if __name__ == '__main__': 157 158 # push model to device 159 model.to(device) 160 161 # training loop 162 163 for epoch in range(1, epochs + 1): 164 165 print(f"training epoch {epoch}...") 166 167 # running stats 168 running_rate_loss = 0 169 running_soft_dist_loss = 0 170 running_hard_dist_loss = 0 171 running_hard_rate_loss = 0 172 running_soft_rate_loss = 0 173 running_total_loss = 0 174 running_rate_metric = 0 175 running_states_rate_metric = 0 176 previous_total_loss = 0 177 running_first_frame_loss = 0 178 179 with tqdm.tqdm(dataloader, unit='batch') as tepoch: 180 for i, (features, rate_lambda, q_ids) in enumerate(tepoch): 181 182 # zero out gradients 183 optimizer.zero_grad() 184 185 # push inputs to device 186 features = features.to(device) 187 q_ids = q_ids.to(device) 188 rate_lambda = rate_lambda.to(device) 189 190 191 rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1) 192 193 # run model 194 model_output = model(features, q_ids) 195 196 # collect outputs 197 z = model_output['z'] 198 states = model_output['states'] 199 outputs_hard_quant = model_output['outputs_hard_quant'] 200 outputs_soft_quant = model_output['outputs_soft_quant'] 201 statistical_model = model_output['statistical_model'] 202 203 # rate loss 204 hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False) 205 soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False) 206 states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False) 207 states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False) 208 soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate)) 209 hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate)) 210 rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss) 211 hard_rate_metric = torch.mean(hard_rate) 212 states_rate_metric = torch.mean(states_hard_rate) 213 214 ## distortion losses 215 216 # hard quantized decoder input 217 distortion_loss_hard_quant = torch.zeros_like(rate_loss) 218 for dec_features, start, stop in outputs_hard_quant: 219 distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant) 220 221 first_frame_loss = torch.zeros_like(rate_loss) 222 for dec_features, start, stop in outputs_hard_quant: 223 first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant) 224 225 # soft quantized decoder input 226 distortion_loss_soft_quant = torch.zeros_like(rate_loss) 227 for dec_features, start, stop in outputs_soft_quant: 228 distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant) 229 230 # total loss 231 total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2 232 233 if args.enable_first_frame_loss: 234 total_loss = .97*total_loss + 0.03 * first_frame_loss 235 236 237 total_loss.backward() 238 239 optimizer.step() 240 241 model.clip_weights() 242 model.sparsify() 243 244 scheduler.step() 245 246 # collect running stats 247 running_hard_dist_loss += float(distortion_loss_hard_quant.detach().cpu()) 248 running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu()) 249 running_rate_loss += float(rate_loss.detach().cpu()) 250 running_rate_metric += float(hard_rate_metric.detach().cpu()) 251 running_states_rate_metric += float(states_rate_metric.detach().cpu()) 252 running_total_loss += float(total_loss.detach().cpu()) 253 running_first_frame_loss += float(first_frame_loss.detach().cpu()) 254 running_soft_rate_loss += float(soft_rate_loss.detach().cpu()) 255 running_hard_rate_loss += float(hard_rate_loss.detach().cpu()) 256 257 if (i + 1) % log_interval == 0: 258 current_loss = (running_total_loss - previous_total_loss) / log_interval 259 tepoch.set_postfix( 260 current_loss=current_loss, 261 total_loss=running_total_loss / (i + 1), 262 dist_hq=running_hard_dist_loss / (i + 1), 263 dist_sq=running_soft_dist_loss / (i + 1), 264 rate_loss=running_rate_loss / (i + 1), 265 rate=running_rate_metric / (i + 1), 266 states_rate=running_states_rate_metric / (i + 1), 267 ffloss=running_first_frame_loss / (i + 1), 268 rateloss_hard=running_hard_rate_loss / (i + 1), 269 rateloss_soft=running_soft_rate_loss / (i + 1) 270 ) 271 previous_total_loss = running_total_loss 272 273 # save checkpoint 274 checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth') 275 checkpoint['state_dict'] = model.state_dict() 276 checkpoint['loss'] = running_total_loss / len(dataloader) 277 checkpoint['epoch'] = epoch 278 torch.save(checkpoint, checkpoint_path) 279