"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

""" PyTorch implementation of the rate-distortion-optimized variational autoencoder (RDOVAE) """

import math as m

import torch
from torch import nn
import torch.nn.functional as F
import sys
import os
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
from utils.sparsification import GRUSparsifier
from torch.nn.utils import weight_norm

# Quantization and rate related utility functions

def soft_pvq(x, k):
    """ soft pyramid vector quantizer """

    # L2 normalization
    x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))


    with torch.no_grad():
        # quantization loop, no need to track gradients here
        x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)

        # set initial scaling factor to k
        scale_factor = k
        x_scaled = scale_factor * x_norm1
        x_quant = torch.round(x_scaled)

        # we aim for ||x_quant||_L1 = k
        for _ in range(10):
            # remove signs and calculate L1 norm
            abs_x_quant = torch.abs(x_quant)
            abs_x_scaled = torch.abs(x_scaled)
            l1_x_quant = torch.sum(abs_x_quant, dim=-1)

            # increase the scale where the quantized L1 norm is too small and decrease it where it is too large
            plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
            minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
            factor = torch.where(l1_x_quant > k, minus, plus)
            factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
            scale_factor = scale_factor * factor.unsqueeze(-1)

            # update x and re-quantize with the new scale
            x_scaled = scale_factor * x_norm1
            x_quant = torch.round(x_scaled)

    # L2 normalization of quantized x
    x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
    quantization_error = x_quant_norm2 - x_norm2

    return x_norm2 + quantization_error.detach()

def cache_parameters(func):
    cache = dict()
    def cached_func(*args):
        if args in cache:
            return cache[args]
        else:
            cache[args] = func(*args)

        return cache[args]
    return cached_func

@cache_parameters
def pvq_codebook_size(n, k):

    if k == 0:
        return 1

    if n == 0:
        return 0

    return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)
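
# Example (editor's sketch, not part of the original file): the number of bits
# needed for the PVQ-quantized initial decoder state follows from the codebook
# size, e.g. for the default 24-dimensional state with 82 pulses (see
# RDOVAE.encode below):
#
#     state_bits = m.log2(pvq_codebook_size(24, 82))
#
# soft_pvq() itself returns a point on the unit sphere; the straight-through
# construction x_norm2 + quantization_error.detach() uses the quantized value
# in the forward pass while gradients flow through the unquantized one.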


def soft_rate_estimate(z, r, reduce=True):
    """ rate approximation with dependent theta Eq. (7)"""

    rate = torch.sum(
        - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6),
        dim=-1
    )

    if reduce:
        rate = torch.mean(rate)

    return rate


def hard_rate_estimate(z, r, theta, reduce=True):
    """ hard rate approximation """

    z_q = torch.round(z)
    p0 = 1 - r ** (0.5 + 0.5 * theta)
    alpha = torch.relu(1 - torch.abs(z_q)) ** 2
    rate = - torch.sum(
        (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6)
         + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)),
        dim=-1
    )

    if reduce:
        rate = torch.mean(rate)

    return rate



def soft_dead_zone(x, dead_zone):
    """ approximates application of a dead zone to x """
    d = dead_zone * 0.05
    return x - d * torch.tanh(x / (0.1 + d))


def hard_quantize(x):
    """ round with copy gradient trick """
    return x + (torch.round(x) - x).detach()


def noise_quantize(x):
    """ simulates quantization with addition of random uniform noise """
    return x + (torch.rand_like(x) - 0.5)


# loss functions


def distortion_loss(y_true, y_pred, rate_lambda=None):
    """ custom distortion loss for LPCNet features """

    if y_true.size(-1) != 20:
        raise ValueError('distortion loss is designed to work with 20 features')

    ceps_error = y_pred[..., :18] - y_true[..., :18]
    pitch_error = 2 * (y_pred[..., 18:19] - y_true[..., 18:19])
    corr_error = y_pred[..., 19:] - y_true[..., 19:]
    pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2

    loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)

    if rate_lambda is not None:
        loss = loss / torch.sqrt(rate_lambda)

    loss = torch.mean(loss)

    return loss


# sampling functions

import random


def random_split(start, stop, num_splits=3, min_len=3):
    get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)])
    candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]

    while get_min_len(candidate) < min_len:
        candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]

    return candidate



# weight initialization and clipping
def init_weights(module):

    if isinstance(module, nn.GRU):
        for p in module.named_parameters():
            if p[0].startswith('weight_hh_'):
                nn.init.orthogonal_(p[1])


def weight_clip_factory(max_value):
    """ weight clipping function constraining the sum of absolute values of adjacent weights """
    def clip_weight_(w):
        stop = w.size(1)
        # omit last column if stop is odd
        if stop % 2:
            stop -= 1
        max_values = max_value * torch.ones_like(w[:, :stop])
        factor = max_value / torch.maximum(max_values,
                                 torch.repeat_interleave(
                                     torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]),
                                     2,
                                     1))
        with torch.no_grad():
            w[:, :stop] *= factor

    def clip_weights(module):
        if isinstance(module, nn.GRU) or isinstance(module, nn.Linear):
            for name, w in module.named_parameters():
                if name.startswith('weight'):
                    clip_weight_(w)

    return clip_weights
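
# Example (editor's sketch, not in the original file): during training a latent
# vector is scaled, pushed through the soft dead zone, and rounded with the
# straight-through trick above, while its rate is estimated from the
# statistical-model parameters r and theta (see RDOVAE.forward below):
#
#     z_scaled = z * quant_scale
#     z_dz     = soft_dead_zone(z_scaled, dead_zone)
#     z_q      = hard_quantize(z_dz) / quant_scale
#     bits     = hard_rate_estimate(z_dz, r_hard, theta_hard, reduce=False)
#
# hard_quantize() keeps the rounded value in the forward pass but copies the
# gradient of the identity, so the encoder remains trainable end to end.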

def n(x):
    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)

# RDOVAE module and submodules

sparsify_start = 12000
sparsify_stop = 24000
sparsify_interval = 100
sparsify_exponent = 3
#sparsify_start = 0
#sparsify_stop = 0

sparse_params1 = {
#    'W_hr' : (1.0, [8, 4], True),
#    'W_hz' : (1.0, [8, 4], True),
#    'W_hn' : (1.0, [8, 4], True),
    'W_ir' : (0.6, [8, 4], False),
    'W_iz' : (0.4, [8, 4], False),
    'W_in' : (0.8, [8, 4], False)
}

sparse_params2 = {
#    'W_hr' : (1.0, [8, 4], True),
#    'W_hz' : (1.0, [8, 4], True),
#    'W_hn' : (1.0, [8, 4], True),
    'W_ir' : (0.3, [8, 4], False),
    'W_iz' : (0.2, [8, 4], False),
    'W_in' : (0.4, [8, 4], False)
}


class MyConv(nn.Module):
    def __init__(self, input_dim, output_dim, dilation=1):
        super(MyConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)

    def forward(self, x, state=None):
        device = x.device
        conv_in = torch.cat([torch.zeros_like(x[:, 0:self.dilation, :], device=device), x], -2).permute(0, 2, 1)
        return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)

class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):

        out = x * torch.sigmoid(self.gate(x))

        return out
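
# Note (editor's sketch, not in the original file): MyConv left-pads its input
# with `dilation` zero frames before the kernel-size-2 convolution, so every
# output frame depends only on the current and past frames (a causal conv):
#
#     conv = MyConv(128, 96, dilation=2)
#     y = conv(torch.zeros(1, 10, 128))   # y has shape (1, 10, 96)
#
# n() adds roughly +/- 0.5 LSB of uniform noise at 8-bit resolution and clamps
# to [-1, 1]; it is applied to most activations to mimic quantized inference.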

class CoreEncoder(nn.Module):
    STATE_HIDDEN = 128
    FRAMES_PER_STEP = 2
    CONV_KERNEL_SIZE = 4

    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core encoder for RDOVAE

            Computes latents and initial decoder states from the input features

        """

        super(CoreEncoder, self).__init__()

        # hyper parameters
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        # derived parameters
        self.input_dim = self.FRAMES_PER_STEP * self.feature_dim

        # layers
        self.dense_1 = nn.Linear(self.input_dim, 64)
        self.gru1 = nn.GRU(64, 64, batch_first=True)
        self.conv1 = MyConv(128, 96)
        self.gru2 = nn.GRU(224, 64, batch_first=True)
        self.conv2 = MyConv(288, 96, dilation=2)
        self.gru3 = nn.GRU(384, 64, batch_first=True)
        self.conv3 = MyConv(448, 96, dilation=2)
        self.gru4 = nn.GRU(544, 64, batch_first=True)
        self.conv4 = MyConv(608, 96, dilation=2)
        self.gru5 = nn.GRU(704, 64, batch_first=True)
        self.conv5 = MyConv(768, 96, dilation=2)

        self.z_dense = nn.Linear(864, self.output_dim)


        self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN)

        self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"encoder: {nb_params} weights")

        # initialize weights
        self.apply(init_weights)


    def forward(self, features):

        # reshape features
        x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))

        batch = x.size(0)
        device = x.device

        # run encoding layer stack
        x = n(torch.tanh(self.dense_1(x)))
        x = torch.cat([x, n(self.gru1(x)[0])], -1)
        x = torch.cat([x, n(self.conv1(x))], -1)
        x = torch.cat([x, n(self.gru2(x)[0])], -1)
        x = torch.cat([x, n(self.conv2(x))], -1)
        x = torch.cat([x, n(self.gru3(x)[0])], -1)
        x = torch.cat([x, n(self.conv3(x))], -1)
        x = torch.cat([x, n(self.gru4(x)[0])], -1)
        x = torch.cat([x, n(self.conv4(x))], -1)
        x = torch.cat([x, n(self.gru5(x)[0])], -1)
        x = torch.cat([x, n(self.conv5(x))], -1)
        z = self.z_dense(x)

        # init state for decoder
        states = torch.tanh(self.state_dense_1(x))
        states = self.state_dense_2(states)

        return z, states
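
# Shape example (editor's sketch, not in the original file; sizes illustrative,
# cond_size/cond_size2 are stored but unused by the layer stack): with the 20
# LPCNet features and FRAMES_PER_STEP = 2 the encoder halves the time resolution:
#
#     enc = CoreEncoder(feature_dim=20, output_dim=80, cond_size=256, cond_size2=256)
#     z, states = enc(torch.zeros(1, 16, 20))
#     # z: (1, 8, 80), states: (1, 8, 24); only the last state of each decoder
#     # chunk is used as the decoder's initial state (see RDOVAE.forward).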


class CoreDecoder(nn.Module):

    FRAMES_PER_STEP = 4

    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core decoder for RDOVAE

            Computes features from latents and an initial state

        """

        super(CoreDecoder, self).__init__()

        # hyper parameters
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        self.input_size = self.input_dim

        # layers
        self.dense_1 = nn.Linear(self.input_size, 96)
        self.gru1 = nn.GRU(96, 96, batch_first=True)
        self.conv1 = MyConv(192, 32)
        self.gru2 = nn.GRU(224, 96, batch_first=True)
        self.conv2 = MyConv(320, 32)
        self.gru3 = nn.GRU(352, 96, batch_first=True)
        self.conv3 = MyConv(448, 32)
        self.gru4 = nn.GRU(480, 96, batch_first=True)
        self.conv4 = MyConv(576, 32)
        self.gru5 = nn.GRU(608, 96, batch_first=True)
        self.conv5 = MyConv(704, 32)
        self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
        self.glu1 = GLU(96)
        self.glu2 = GLU(96)
        self.glu3 = GLU(96)
        self.glu4 = GLU(96)
        self.glu5 = GLU(96)
        self.hidden_init = nn.Linear(self.state_size, 128)
        self.gru_init = nn.Linear(128, 480)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"decoder: {nb_params} weights")
        # initialize weights
        self.apply(init_weights)
        self.sparsifier = []
        self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))

    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def forward(self, z, initial_state):

        hidden = torch.tanh(self.hidden_init(initial_state))
        gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))
        h1_state = gru_state[:, :, :96].contiguous()
        h2_state = gru_state[:, :, 96:192].contiguous()
        h3_state = gru_state[:, :, 192:288].contiguous()
        h4_state = gru_state[:, :, 288:384].contiguous()
        h5_state = gru_state[:, :, 384:].contiguous()

        # run decoding layer stack
        x = n(torch.tanh(self.dense_1(z)))

        x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
        x = torch.cat([x, n(self.conv1(x))], -1)
        x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
        x = torch.cat([x, n(self.conv2(x))], -1)
        x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
        x = torch.cat([x, n(self.conv3(x))], -1)
        x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
        x = torch.cat([x, n(self.conv4(x))], -1)
        x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
        x = torch.cat([x, n(self.conv5(x))], -1)

        # output layer and reshaping
        x10 = self.output(x)
        features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))

        return features
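
# Note (editor's sketch, not in the original file): the decoder upsamples by
# FRAMES_PER_STEP = 4 while the encoder downsamples by 2, so one decoder step
# consumes every second latent vector (z_stride = 2 in get_decoder_chunks):
#
#     dec = CoreDecoder(input_dim=80, output_dim=20, cond_size=256, cond_size2=256)
#     feats = dec(torch.zeros(1, 4, 80), torch.zeros(1, 1, 24))
#     # feats: (1, 16, 20), i.e. 4 latent steps -> 16 feature frames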


class StatisticalModel(nn.Module):
    def __init__(self, quant_levels, latent_dim, state_dim):
        """ Statistical model for latent space

            Computes scaling, deadzone, r, and theta

        """

        super(StatisticalModel, self).__init__()

        # copy parameters
        self.latent_dim = latent_dim
        self.state_dim = state_dim
        self.total_dim = latent_dim + state_dim
        self.quant_levels = quant_levels
        self.embedding_dim = 6 * self.total_dim

        # quantization embedding
        self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)

        # initialize embedding to 0
        with torch.no_grad():
            self.quant_embedding.weight[:] = 0


    def forward(self, quant_ids):
        """ takes quant_ids and returns statistical model parameters"""

        x = self.quant_embedding(quant_ids)

        # CAVE: theta_soft is not used anymore. Kick it out?
        quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
        dead_zone = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
        theta_soft = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
        r_soft = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
        theta_hard = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
        r_hard = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])


        return {
            'quant_embedding' : x,
            'quant_scale' : quant_scale,
            'dead_zone' : dead_zone,
            'r_hard' : r_hard,
            'theta_hard' : theta_hard,
            'r_soft' : r_soft,
            'theta_soft' : theta_soft
        }
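
# Example (editor's sketch, not in the original file): each quantization level
# indexes one row of the embedding, which is split into six groups of
# latent_dim + state_dim entries:
#
#     stats = StatisticalModel(quant_levels=16, latent_dim=80, state_dim=24)
#     params = stats(torch.zeros(1, 8, dtype=torch.long))
#     # params['quant_scale'].shape == (1, 8, 104); same for dead_zone,
#     # r_hard, theta_hard, r_soft and theta_soft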


class RDOVAE(nn.Module):
    def __init__(self,
                 feature_dim,
                 latent_dim,
                 quant_levels,
                 cond_size,
                 cond_size2,
                 state_dim=24,
                 split_mode='split',
                 clip_weights=False,
                 pvq_num_pulses=82,
                 state_dropout_rate=0):

        super(RDOVAE, self).__init__()

        self.feature_dim = feature_dim
        self.latent_dim = latent_dim
        self.quant_levels = quant_levels
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.split_mode = split_mode
        self.state_dim = state_dim
        self.pvq_num_pulses = pvq_num_pulses
        self.state_dropout_rate = state_dropout_rate

        # submodules; encoder and decoder share the statistical model
        self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))

        self.enc_stride = CoreEncoder.FRAMES_PER_STEP
        self.dec_stride = CoreDecoder.FRAMES_PER_STEP

        if clip_weights:
            self.weight_clip_fn = weight_clip_factory(0.496)
        else:
            self.weight_clip_fn = None

        if self.dec_stride % self.enc_stride != 0:
            raise ValueError("get_decoder_chunks_generic: encoder stride does not divide decoder stride")

    def clip_weights(self):
        if self.weight_clip_fn is not None:
            self.apply(self.weight_clip_fn)

    def sparsify(self):
        #self.core_encoder.module.sparsify()
        self.core_decoder.module.sparsify()

    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset=4):

        enc_stride = self.enc_stride
        dec_stride = self.dec_stride

        stride = dec_stride // enc_stride

        chunks = []

        for offset in range(stride):
            # start is the smallest number congruent to offset mod stride that decodes to a valid feature range
            start = offset
            while enc_stride * (start + 1) - dec_stride < 0:
                start += stride

            # check if start is a valid index
            if start >= z_frames:
                raise ValueError("get_decoder_chunks_generic: range too small")

            # stop is the smallest number >= z_frames that is congruent to offset mod stride
            stop = z_frames - (z_frames % stride) + offset
            while stop < z_frames:
                stop += stride

            # calculate split points
            length = (stop - start)
            if mode == 'split':
                split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
            elif mode == 'random_split':
                split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
            else:
                raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")


            for i in range(chunks_per_offset):
                # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
                # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
                # provided that j - i = 1 mod stride
                chunks.append({
                    'z_start' : split_points[i],
                    'z_stop' : split_points[i + 1] - stride + 1,
                    'z_stride' : stride,
                    'features_start' : enc_stride * (split_points[i] + 1) - dec_stride,
                    'features_stop' : enc_stride * (split_points[i + 1] - stride + 1)
                })

        return chunks


    def forward(self, features, q_id):

        # calculate statistical model from quantization ID
        statistical_model = self.statistical_model(q_id)

        # run encoder
        z, states = self.core_encoder(features)

        # scaling, dead-zone and quantization
        z = z * statistical_model['quant_scale'][:, :, :self.latent_dim]
        z = soft_dead_zone(z, statistical_model['dead_zone'][:, :, :self.latent_dim])

        # quantization
        z_q = hard_quantize(z) / statistical_model['quant_scale'][:, :, :self.latent_dim]
        z_n = noise_quantize(z) / statistical_model['quant_scale'][:, :, :self.latent_dim]
        #states_q = soft_pvq(states, self.pvq_num_pulses)
        states = states * statistical_model['quant_scale'][:, :, self.latent_dim:]
        states = soft_dead_zone(states, statistical_model['dead_zone'][:, :, self.latent_dim:])

        states_q = hard_quantize(states) / statistical_model['quant_scale'][:, :, self.latent_dim:]
        states_n = noise_quantize(states) / statistical_model['quant_scale'][:, :, self.latent_dim:]

        if self.state_dropout_rate > 0:
            drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
            mask = torch.ones_like(states_q)
            mask[drop] = 0
            states_q = states_q * mask

        # decoder
        chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)

        outputs_hq = []
        outputs_sq = []
        for chunk in chunks:
            # decoder with hard quantized input
            z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))


            # decoder with soft quantized input
            z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

        return {
            'outputs_hard_quant' : outputs_hq,
            'outputs_soft_quant' : outputs_sq,
            'z' : z,
            'states' : states,
            'statistical_model' : statistical_model
        }

    def encode(self, features):
        """ encoder with quantization and rate estimation """

        z, states = self.core_encoder(features)

        # quantization of initial states
        states = soft_pvq(states, self.pvq_num_pulses)
        state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))

        return z, states, state_size

    def decode(self, z, initial_state):
        """ decoder (flips sequences by itself) """

        z_reverse = torch.flip(z, [1])
        features_reverse = self.core_decoder(z_reverse, initial_state)
        features = torch.flip(features_reverse, [1])

        return features

    def quantize(self, z, q_ids):
        """ quantization of latent vectors """

        stats = self.statistical_model(q_ids)

        zq = z * stats['quant_scale'][..., :self.latent_dim]
        zq = soft_dead_zone(zq, stats['dead_zone'][..., :self.latent_dim])
        zq = torch.round(zq)

        sizes = hard_rate_estimate(zq, stats['r_hard'][..., :self.latent_dim], stats['theta_hard'][..., :self.latent_dim], reduce=False)

        return zq, sizes

    def unquantize(self, zq, q_ids):
        """ re-scaling of latent vector """

        stats = self.statistical_model(q_ids)

        z = zq / stats['quant_scale'][..., :self.latent_dim]

        return z

    def freeze_model(self):

        # freeze all parameters except those of the statistical model
        for p in self.parameters():
            p.requires_grad = False

        for p in self.statistical_model.parameters():
            p.requires_grad = True
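
# End-to-end sketch (editor's addition, not in the original file): a single
# training-style forward pass with random inputs; feature, latent, and
# quantization-level sizes are illustrative only.
#
#     model = RDOVAE(feature_dim=20, latent_dim=80, quant_levels=16,
#                    cond_size=256, cond_size2=256)
#     features = torch.randn(1, 64, 20)               # 64 frames of 20 features
#     q_id = torch.zeros(1, 32, dtype=torch.long)     # one level per latent step
#     out = model(features, q_id)
#     # out['z']: (1, 32, 80); out['outputs_hard_quant'] holds per-chunk
#     # reconstructions together with their feature-frame ranges, which a
#     # training loop can compare against the target using distortion_loss().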