1"""
2/* Copyright (c) 2022 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

""" PyTorch implementation of a rate-distortion-optimized variational autoencoder """

import math as m

import torch
from torch import nn
import torch.nn.functional as F
import sys
import os
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
from utils.sparsification import GRUSparsifier
from torch.nn.utils import weight_norm

# Quantization and rate-related utility functions

def soft_pvq(x, k):
    """ soft pyramid vector quantizer """

    # L2 normalization
    x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))


    with torch.no_grad():
        # quantization loop, no need to track gradients here
        x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)

        # set initial scaling factor to k
        scale_factor = k
        x_scaled = scale_factor * x_norm1
        x_quant = torch.round(x_scaled)

        # we aim for ||x_quant||_L1 = k
        for _ in range(10):
            # remove signs and calculate L1 norm
            abs_x_quant = torch.abs(x_quant)
            abs_x_scaled = torch.abs(x_scaled)
            l1_x_quant = torch.sum(abs_x_quant, dim=-1)

            # increase the scale factor where the quantized L1 norm is too small,
            # decrease it where it is too large
            plus  = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
            minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
            factor = torch.where(l1_x_quant > k, minus, plus)
            factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
            scale_factor = scale_factor * factor.unsqueeze(-1)

            # update x and its quantization with the new scale factor
            x_scaled = scale_factor * x_norm1
            x_quant = torch.round(x_scaled)

    # L2 normalization of quantized x
    x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
    quantization_error = x_quant_norm2 - x_norm2

    return x_norm2 + quantization_error.detach()

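# Example (illustrative only, not used by the training code): soft_pvq() returns a
# vector that is numerically equal to the L2-normalized PVQ codeword while gradients
# flow straight through to x:
#
#   x = torch.randn(4, 7, 24)
#   y = soft_pvq(x, 82)        # torch.norm(y, dim=-1) is (close to) all ones
#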
def cache_parameters(func):
    cache = dict()
    def cached_func(*args):
        if args in cache:
            return cache[args]
        else:
            cache[args] = func(*args)

        return cache[args]
    return cached_func

@cache_parameters
def pvq_codebook_size(n, k):

    if k == 0:
        return 1

    if n == 0:
        return 0

    return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)

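# pvq_codebook_size(n, k) implements the standard PVQ recursion
# V(n, k) = V(n - 1, k) + V(n, k - 1) + V(n - 1, k - 1), counting the integer vectors
# of dimension n with L1 norm exactly k; e.g. V(1, 1) = 2 and V(2, 1) = 4.
# RDOVAE.encode() below uses log2(pvq_codebook_size(state_dim, pvq_num_pulses)) as the
# bit cost of the PVQ-quantized initial decoder state.
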

def soft_rate_estimate(z, r, reduce=True):
    """ rate approximation with dependent theta, Eq. (7) """

    rate = torch.sum(
        - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6),
        dim=-1
    )

    if reduce:
        rate = torch.mean(rate)

    return rate


def hard_rate_estimate(z, r, theta, reduce=True):
    """ hard rate approximation """

    z_q = torch.round(z)
    p0 = 1 - r ** (0.5 + 0.5 * theta)
    alpha = torch.relu(1 - torch.abs(z_q)) ** 2
    rate = - torch.sum(
        (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6)
        + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)),
        dim=-1
    )

    if reduce:
        rate = torch.mean(rate)

    return rate

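# Worked example: with r = 0.5, theta = 1.0 and z_q = 0, hard_rate_estimate gives
# p0 = 1 - 0.5 = 0.5 and alpha = 1, so the per-dimension cost is
# -log2(0.5 * 0.5**0 + 1e-6) ~= 1 bit; non-zero z_q has alpha = 0 and is billed by
# the (1 - alpha) tail term instead, costing correspondingly more bits.
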


def soft_dead_zone(x, dead_zone):
    """ approximates application of a dead zone to x """
    d = dead_zone * 0.05
    return x - d * torch.tanh(x / (0.1 + d))

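# Note: for small |x| this behaves like x * 0.1 / (0.1 + d), pulling values towards
# zero, while for large |x| it approaches x - d * sign(x); i.e. it is a smooth,
# differentiable approximation of a quantizer dead zone whose width grows with dead_zone.
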

def hard_quantize(x):
    """ round with copy gradient trick """
    return x + (torch.round(x) - x).detach()

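# Note: hard_quantize(x) evaluates to torch.round(x) in the forward pass, but the
# rounding residual is detached, so the backward pass sees the identity
# (straight-through estimator) and the operation stays usable inside the loss.
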

def noise_quantize(x):
    """ simulates quantization with addition of random uniform noise """
    return x + (torch.rand_like(x) - 0.5)


# loss functions


def distortion_loss(y_true, y_pred, rate_lambda=None):
    """ custom distortion loss for LPCNet features """

    if y_true.size(-1) != 20:
        raise ValueError('distortion loss is designed to work with 20 features')

    ceps_error   = y_pred[..., :18] - y_true[..., :18]
    pitch_error  = 2*(y_pred[..., 18:19] - y_true[..., 18:19])
    corr_error   = y_pred[..., 19:] - y_true[..., 19:]
    pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2

    loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)

    if rate_lambda is not None:
        loss = loss / torch.sqrt(rate_lambda)

    loss = torch.mean(loss)

    return loss

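# Note: the 20 features are handled as 18 cepstral coefficients (squared error),
# one pitch parameter (L1 error, weighted by relu(corr + 0.5)**2 so that pitch
# mistakes mainly count in voiced frames) and one pitch correlation (squared error);
# dividing by sqrt(rate_lambda), when given, presumably rebalances distortion against
# the rate term of the rate-distortion objective.
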

# sampling functions

import random


def random_split(start, stop, num_splits=3, min_len=3):
    get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)])
    candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]

    while get_min_len(candidate) < min_len:
        candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]

    return candidate

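# Note: random_split() rejection-samples num_splits interior split points and returns
# them together with the end points; e.g. one possible result of random_split(0, 20)
# is [0, 4, 9, 14, 20], where every segment length is at least min_len.
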


# weight initialization and clipping
def init_weights(module):

    if isinstance(module, nn.GRU):
        for p in module.named_parameters():
            if p[0].startswith('weight_hh_'):
                nn.init.orthogonal_(p[1])


def weight_clip_factory(max_value):
    """ weight clipping function constraining the sum of absolute values of adjacent weight pairs """
    def clip_weight_(w):
        stop = w.size(1)
        # omit last column if stop is odd
        if stop % 2:
            stop -= 1
        max_values = max_value * torch.ones_like(w[:, :stop])
        factor = max_value / torch.maximum(max_values,
                                 torch.repeat_interleave(
                                     torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]),
                                     2,
                                     1))
        with torch.no_grad():
            w[:, :stop] *= factor

    def clip_weights(module):
        if isinstance(module, nn.GRU) or isinstance(module, nn.Linear):
            for name, w in module.named_parameters():
                if name.startswith('weight'):
                    clip_weight_(w)

    return clip_weights

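# Typical usage (this mirrors RDOVAE.__init__ / RDOVAE.clip_weights below):
#
#   clip_fn = weight_clip_factory(0.496)
#   model.apply(clip_fn)    # afterwards |w[:, 2*i]| + |w[:, 2*i + 1]| <= 0.496
#
# The bound presumably keeps adjacent weight pairs representable in the fixed-point
# arithmetic of the C inference code.
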
def n(x):
    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)

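# Note: n() adds uniform noise in [-1/254, 1/254] and clamps to [-1, 1]; it is applied
# to the intermediate activations below, presumably to make the model robust to the
# 8-bit activation quantization used at inference time.
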
# RDOVAE module and submodules

sparsify_start     = 12000
sparsify_stop      = 24000
sparsify_interval  = 100
sparsify_exponent  = 3
#sparsify_start     = 0
#sparsify_stop      = 0

sparse_params1 = {
#                'W_hr' : (1.0, [8, 4], True),
#                'W_hz' : (1.0, [8, 4], True),
#                'W_hn' : (1.0, [8, 4], True),
                'W_ir' : (0.6, [8, 4], False),
                'W_iz' : (0.4, [8, 4], False),
                'W_in' : (0.8, [8, 4], False)
                }

sparse_params2 = {
#                'W_hr' : (1.0, [8, 4], True),
#                'W_hz' : (1.0, [8, 4], True),
#                'W_hn' : (1.0, [8, 4], True),
                'W_ir' : (0.3, [8, 4], False),
                'W_iz' : (0.2, [8, 4], False),
                'W_in' : (0.4, [8, 4], False)
                }


class MyConv(nn.Module):
    def __init__(self, input_dim, output_dim, dilation=1):
        super(MyConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)

    def forward(self, x, state=None):
        device = x.device
        conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
        return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)

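# Note: MyConv is a causal kernel-size-2 convolution over (batch, time, channels)
# tensors: the input is left-padded with `dilation` zero frames, so each output frame
# depends only on the current frame and one frame `dilation` steps in the past.
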
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):

        out = x * torch.sigmoid(self.gate(x))

        return out

class CoreEncoder(nn.Module):
    STATE_HIDDEN = 128
    FRAMES_PER_STEP = 2
    CONV_KERNEL_SIZE = 4

    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core encoder for RDOVAE

            Computes latents and initial decoder states from the feature sequence

        """

        super(CoreEncoder, self).__init__()

        # hyper parameters
        self.feature_dim        = feature_dim
        self.output_dim         = output_dim
        self.cond_size          = cond_size
        self.cond_size2         = cond_size2
        self.state_size         = state_size

        # derived parameters
        self.input_dim = self.FRAMES_PER_STEP * self.feature_dim

        # layers
        self.dense_1 = nn.Linear(self.input_dim, 64)
        self.gru1 = nn.GRU(64, 64, batch_first=True)
        self.conv1 = MyConv(128, 96)
        self.gru2 = nn.GRU(224, 64, batch_first=True)
        self.conv2 = MyConv(288, 96, dilation=2)
        self.gru3 = nn.GRU(384, 64, batch_first=True)
        self.conv3 = MyConv(448, 96, dilation=2)
        self.gru4 = nn.GRU(544, 64, batch_first=True)
        self.conv4 = MyConv(608, 96, dilation=2)
        self.gru5 = nn.GRU(704, 64, batch_first=True)
        self.conv5 = MyConv(768, 96, dilation=2)

        self.z_dense = nn.Linear(864, self.output_dim)


        self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN)

        self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
        nb_params = sum(p.numel() for p in self.parameters())
        print(f"encoder: {nb_params} weights")

        # initialize weights
        self.apply(init_weights)


    def forward(self, features):

        # reshape features
        x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))

        batch = x.size(0)
        device = x.device

        # run encoding layer stack
        x = n(torch.tanh(self.dense_1(x)))
        x = torch.cat([x, n(self.gru1(x)[0])], -1)
        x = torch.cat([x, n(self.conv1(x))], -1)
        x = torch.cat([x, n(self.gru2(x)[0])], -1)
        x = torch.cat([x, n(self.conv2(x))], -1)
        x = torch.cat([x, n(self.gru3(x)[0])], -1)
        x = torch.cat([x, n(self.conv3(x))], -1)
        x = torch.cat([x, n(self.gru4(x)[0])], -1)
        x = torch.cat([x, n(self.conv4(x))], -1)
        x = torch.cat([x, n(self.gru5(x)[0])], -1)
        x = torch.cat([x, n(self.conv5(x))], -1)
        z = self.z_dense(x)

        # init state for decoder
        states = torch.tanh(self.state_dense_1(x))
        states = self.state_dense_2(states)

        return z, states


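# Shape summary: CoreEncoder.forward takes features of shape
# (batch, num_frames, feature_dim) with num_frames a multiple of FRAMES_PER_STEP (2)
# and returns z of shape (batch, num_frames // 2, output_dim) together with candidate
# initial decoder states of shape (batch, num_frames // 2, state_size).
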


class CoreDecoder(nn.Module):

    FRAMES_PER_STEP = 4

    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core decoder for RDOVAE

            Computes features from latents and an initial state

        """

        super(CoreDecoder, self).__init__()

        # hyper parameters
        self.input_dim  = input_dim
        self.output_dim = output_dim
        self.cond_size  = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        self.input_size = self.input_dim

        # layers
        self.dense_1    = nn.Linear(self.input_size, 96)
        self.gru1 = nn.GRU(96, 96, batch_first=True)
        self.conv1 = MyConv(192, 32)
        self.gru2 = nn.GRU(224, 96, batch_first=True)
        self.conv2 = MyConv(320, 32)
        self.gru3 = nn.GRU(352, 96, batch_first=True)
        self.conv3 = MyConv(448, 32)
        self.gru4 = nn.GRU(480, 96, batch_first=True)
        self.conv4 = MyConv(576, 32)
        self.gru5 = nn.GRU(608, 96, batch_first=True)
        self.conv5 = MyConv(704, 32)
        self.output  = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
        self.glu1 = GLU(96)
        self.glu2 = GLU(96)
        self.glu3 = GLU(96)
        self.glu4 = GLU(96)
        self.glu5 = GLU(96)
        self.hidden_init = nn.Linear(self.state_size, 128)
        self.gru_init = nn.Linear(128, 480)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"decoder: {nb_params} weights")
        # initialize weights
        self.apply(init_weights)
        self.sparsifier = []
        self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))

    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def forward(self, z, initial_state):

        hidden = torch.tanh(self.hidden_init(initial_state))
        gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))
        h1_state = gru_state[:,:,:96].contiguous()
        h2_state = gru_state[:,:,96:192].contiguous()
        h3_state = gru_state[:,:,192:288].contiguous()
        h4_state = gru_state[:,:,288:384].contiguous()
        h5_state = gru_state[:,:,384:].contiguous()

        # run decoding layer stack
        x = n(torch.tanh(self.dense_1(z)))

        x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
        x = torch.cat([x, n(self.conv1(x))], -1)
        x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
        x = torch.cat([x, n(self.conv2(x))], -1)
        x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
        x = torch.cat([x, n(self.conv3(x))], -1)
        x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
        x = torch.cat([x, n(self.conv4(x))], -1)
        x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
        x = torch.cat([x, n(self.conv5(x))], -1)

        # output layer and reshaping
        x10 = self.output(x)
        features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))

        return features

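# Shape summary: CoreDecoder.forward takes z of shape (batch, N, input_dim) and an
# initial state of shape (batch, 1, state_size), which hidden_init / gru_init map to
# the start states of the five GRUs, and returns features of shape
# (batch, N * FRAMES_PER_STEP, output_dim).
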

class StatisticalModel(nn.Module):
    def __init__(self, quant_levels, latent_dim, state_dim):
        """ Statistical model for latent space

            Computes scaling, deadzone, r, and theta

        """

        super(StatisticalModel, self).__init__()

        # copy parameters
        self.latent_dim     = latent_dim
        self.state_dim      = state_dim
        self.total_dim      = latent_dim + state_dim
        self.quant_levels   = quant_levels
        self.embedding_dim  = 6 * self.total_dim

        # quantization embedding
        self.quant_embedding    = nn.Embedding(quant_levels, self.embedding_dim)

        # initialize embedding to 0
        with torch.no_grad():
            self.quant_embedding.weight[:] = 0


    def forward(self, quant_ids):
        """ takes quant_ids and returns statistical model parameters """

        x = self.quant_embedding(quant_ids)

        # CAVE: theta_soft is not used anymore. Kick it out?
        quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
        dead_zone   = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
        theta_soft  = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
        r_soft      = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
        theta_hard  = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
        r_hard      = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])


        return {
            'quant_embedding'   : x,
            'quant_scale'       : quant_scale,
            'dead_zone'         : dead_zone,
            'r_hard'            : r_hard,
            'theta_hard'        : theta_hard,
            'r_soft'            : r_soft,
            'theta_soft'        : theta_soft
        }

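# Note: every tensor returned by StatisticalModel.forward has total_dim channels in
# its last dimension; RDOVAE.forward below uses the first latent_dim channels for the
# latents z and the remaining state_dim channels for the initial decoder state.
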

class RDOVAE(nn.Module):
    def __init__(self,
                 feature_dim,
                 latent_dim,
                 quant_levels,
                 cond_size,
                 cond_size2,
                 state_dim=24,
                 split_mode='split',
                 clip_weights=False,
                 pvq_num_pulses=82,
                 state_dropout_rate=0):

        super(RDOVAE, self).__init__()

        self.feature_dim    = feature_dim
        self.latent_dim     = latent_dim
        self.quant_levels   = quant_levels
        self.cond_size      = cond_size
        self.cond_size2     = cond_size2
        self.split_mode     = split_mode
        self.state_dim      = state_dim
        self.pvq_num_pulses = pvq_num_pulses
        self.state_dropout_rate = state_dropout_rate

        # submodules; the encoder and decoder share the statistical model
        self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))

        self.enc_stride = CoreEncoder.FRAMES_PER_STEP
        self.dec_stride = CoreDecoder.FRAMES_PER_STEP

        if clip_weights:
            self.weight_clip_fn = weight_clip_factory(0.496)
        else:
            self.weight_clip_fn = None

        if self.dec_stride % self.enc_stride != 0:
            raise ValueError("get_decoder_chunks_generic: encoder stride does not divide decoder stride")

    def clip_weights(self):
        if self.weight_clip_fn is not None:
            self.apply(self.weight_clip_fn)

    def sparsify(self):
        #self.core_encoder.module.sparsify()
        self.core_decoder.module.sparsify()

    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset=4):

        enc_stride = self.enc_stride
        dec_stride = self.dec_stride

        stride = dec_stride // enc_stride

        chunks = []

        for offset in range(stride):
            # start is the smallest number congruent to offset mod stride that decodes to a valid range
            start = offset
            while enc_stride * (start + 1) - dec_stride < 0:
                start += stride

            # check if start is a valid index
            if start >= z_frames:
                raise ValueError("get_decoder_chunks_generic: range too small")

            # stop is the smallest number >= z_frames that is congruent to offset mod stride
            stop = z_frames - (z_frames % stride) + offset
            while stop < z_frames:
                stop += stride

            # calculate split points
            length = (stop - start)
            if mode == 'split':
                split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
            elif mode == 'random_split':
                split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
            else:
                raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")


            for i in range(chunks_per_offset):
                # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
                # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
                # provided that j - i = 1 mod stride
                chunks.append({
                    'z_start'         : split_points[i],
                    'z_stop'          : split_points[i + 1] - stride + 1,
                    'z_stride'        : stride,
                    'features_start'  : enc_stride * (split_points[i] + 1) - dec_stride,
                    'features_stop'   : enc_stride * (split_points[i + 1] - stride + 1)
                })

        return chunks

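    # Worked example: with enc_stride = 2, dec_stride = 4 (so stride = 2),
    # z_frames = 32 and mode = 'split', offset 0 yields split points [2, 8, 16, 24, 32]
    # and the first chunk is
    #   {'z_start': 2, 'z_stop': 7, 'z_stride': 2, 'features_start': 2, 'features_stop': 14},
    # i.e. the decoder consumes z frames 2, 4, 6 and reconstructs feature frames 2..13.
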

    def forward(self, features, q_id):

        # calculate statistical model from quantization ID
        statistical_model = self.statistical_model(q_id)

        # run encoder
        z, states = self.core_encoder(features)

        # scaling, dead-zone and quantization
        z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
        z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])

        # quantization
        z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
        z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
        #states_q = soft_pvq(states, self.pvq_num_pulses)
        states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
        states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])

        states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
        states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]

        if self.state_dropout_rate > 0:
            drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
            mask = torch.ones_like(states_q)
            mask[drop] = 0
            states_q = states_q * mask

        # decoder
        chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)

        outputs_hq = []
        outputs_sq = []
        for chunk in chunks:
            # decoder with hard quantized input
            z_dec_reverse       = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state   = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse,  dec_initial_state)
            outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))


            # decoder with soft quantized input
            z_dec_reverse       = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :],  [1])
            dec_initial_state   = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse    = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

        return {
            'outputs_hard_quant' : outputs_hq,
            'outputs_soft_quant' : outputs_sq,
            'z'                 : z,
            'states'            : states,
            'statistical_model' : statistical_model
        }

    def encode(self, features):
        """ encoder with quantization and rate estimation """

        z, states = self.core_encoder(features)

        # quantization of initial states
        states = soft_pvq(states, self.pvq_num_pulses)
        state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))

        return z, states, state_size

    def decode(self, z, initial_state):
        """ decoder (flips sequences by itself) """

        z_reverse       = torch.flip(z, [1])
        features_reverse = self.core_decoder(z_reverse, initial_state)
        features = torch.flip(features_reverse, [1])

        return features

    def quantize(self, z, q_ids):
        """ quantization of latent vectors """

        stats = self.statistical_model(q_ids)

        zq = z * stats['quant_scale'][:, :, :self.latent_dim]
        zq = soft_dead_zone(zq, stats['dead_zone'][:, :, :self.latent_dim])
        zq = torch.round(zq)

        sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)

        return zq, sizes

    def unquantize(self, zq, q_ids):
        """ re-scaling of latent vector """

        stats = self.statistical_model(q_ids)

        z = zq / stats['quant_scale'][:,:,:self.latent_dim]

        return z

    def freeze_model(self):

        # freeze all parameters
        for p in self.parameters():
            p.requires_grad = False

        for p in self.statistical_model.parameters():
            p.requires_grad = True
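

# Minimal smoke test of the model wiring (illustrative sketch only; the hyperparameters
# below are made up, the real values come from the training script).
if __name__ == "__main__":
    torch.manual_seed(0)

    # assumed hyperparameters, for illustration only
    feature_dim, latent_dim, quant_levels = 20, 80, 16
    model = RDOVAE(feature_dim, latent_dim, quant_levels, cond_size=256, cond_size2=256)

    # 64 feature frames -> 32 latent frames (enc_stride = 2)
    features = torch.randn(1, 64, feature_dim)
    q_id = torch.randint(0, quant_levels, (1, 64 // model.enc_stride))

    out = model(features, q_id)
    print("z:", out['z'].shape, "states:", out['states'].shape)
    for feat, start, stop in out['outputs_hard_quant']:
        print("decoded", tuple(feat.shape), "covers feature frames", start, "to", stop)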
720