1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import torch
from torch import nn
from utils.layers.subconditioner import get_subconditioner
from utils.layers import DualFC

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step

from utils.misc import interleave_tensors
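
# A minimal sketch of the config layout this model expects (key names are
# taken from the reads in __init__ below; the values shown are illustrative
# placeholders, not canonical defaults):
#
# config = {
#     'input_layout':    {'signals': {...}, 'features': {...}},
#     'feature_history': ..., 'feature_lookahead': ..., 'signals': [...],
#     'feature_dimension': ..., 'period_embedding_dim': ..., 'period_levels': ...,
#     'feature_conditioning_dim': ..., 'feature_conv_kernel_size': ...,
#     'frame_size': ..., 'signal_levels': ..., 'signal_embedding_dim': ...,
#     'gru_a_units': ..., 'gru_b_units': ..., 'output_levels': ...,
#     'subconditioning': {'subconditioning_a': {...}, 'subconditioning_b': {...}},
#     'sparsification': {'gru_a': {...}, 'gru_b': {...}}  # gru entries optional
# }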


# MultiRateLPCNet
class MultiRateLPCNet(nn.Module):
    def __init__(self, config):
        super(MultiRateLPCNet, self).__init__()

        # general parameters
        self.input_layout       = config['input_layout']
        self.feature_history    = config['feature_history']
        self.feature_lookahead  = config['feature_lookahead']
        self.signals            = config['signals']

        # frame rate network parameters
        self.feature_dimension          = config['feature_dimension']
        self.period_embedding_dim       = config['period_embedding_dim']
        self.period_levels              = config['period_levels']
        self.feature_channels           = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim   = config['feature_conditioning_dim']
        self.feature_conv_kernel_size   = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding   = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1      = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2      = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1     = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2     = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)

        # sample rate network parameters
        self.frame_size             = config['frame_size']
        self.signal_levels          = config['signal_levels']
        self.signal_embedding_dim   = config['signal_embedding_dim']
        self.gru_a_units            = config['gru_a_units']
        self.gru_b_units            = config['gru_b_units']
        self.output_levels          = config['output_levels']

        # subconditioning B
        sub_config = config['subconditioning']['subconditioning_b']
        self.substeps_b = sub_config['number_of_subsamples']
        self.subcondition_signals_b = sub_config['signals']
        self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_b_units
        self.subconditioner_b = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, len(sub_config['signals']),
            **kwargs)

        # subconditioning A
        sub_config = config['subconditioning']['subconditioning_a']
        self.substeps_a = sub_config['number_of_subsamples']
        self.subcondition_signals_a = sub_config['signals']
        self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if kwargs is None:
            kwargs = dict()

        state_size = self.gru_a_units
        self.subconditioner_a = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
            **kwargs)

        # wrap up subconditioning: group_size_gru_a is the number of timesteps
        # grouped into a single sample-domain input for GRU A, and
        # group_size_subcondition_a is the number of samples grouped as input
        # to the pre-GRU-B subconditioning
        self.group_size_gru_a = self.substeps_a * self.substeps_b
        self.group_size_subcondition_a = self.substeps_b
        self.gru_a_rate_divider = self.group_size_gru_a
        self.gru_b_rate_divider = self.substeps_b
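        # Example (hypothetical values): with substeps_a = 2 and substeps_b = 2,
        # GRU A consumes groups of 4 samples and runs at fs / 4, subconditioner A
        # and GRU B run at fs / 2, and the substeps_b dual FCs each emit one
        # sample per GRU B step, restoring the full sample rate fs.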

        # gru sizes
        self.gru_a_input_dim        = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim        = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
        self.signals_idx            = [self.input_layout['signals'][key] for key in self.signals]

        # sample rate network layers
        self.signal_embedding   = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a              = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b              = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)

        # sparsification
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent']))
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                     gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent']))
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                     gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)

        # dual FCs
        self.dual_fc = []
        for i in range(self.substeps_b):
            dim = self.subconditioner_b.get_output_dim(i)
            self.dual_fc.append(DualFC(dim, self.output_levels))
            self.add_module(f"dual_fc_{i}", self.dual_fc[-1])

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        gflops = 0

        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
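        # Cost model for the two convolutions and two dense layers, presumably
        # assuming feature_conv_kernel_size == 3: conv1 costs 3 * feature_channels
        # MACs per output, conv2 costs 3 * conditioning_dim, the two dense layers
        # cost 2 * conditioning_dim; the leading factor 2 converts MACs to FLOPS.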
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs / self.group_size_gru_a
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # subconditioning a
        subcond_a_rate = fs / self.substeps_b
        subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
        if verbose:
            print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
        gflops += subconditioning_a_complexity

        # gru b
        gru_b_rate = fs / self.substeps_b
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # subconditioning b
        subcond_b_rate = fs
        subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
        if verbose:
            print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
        gflops += subconditioning_b_complexity

        # dual fcs
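        # DualFC cost estimate: two weight matrices at 2 FLOPS per MAC give the
        # 4 * input_size * output_size term; 22 * output_size is presumably a
        # budget for the per-output activations and gating.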
        for i, fc in enumerate(self.dual_fc):
            rate = fs / len(self.dual_fc)
            input_size = fc.dense1.in_features
            output_size = fc.dense1.out_features
            dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
            if hierarchical_sampling:
                dual_fc_complexity /= 8
            if verbose:
                print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
            gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops

    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def frame_rate_network(self, features, periods):

        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

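        # both convolutions use 'valid' padding, so each trims
        # (feature_conv_kernel_size - 1) frames; c ends up
        # 2 * (feature_conv_kernel_size - 1) frames shorter than the input,
        # which feature_history and feature_lookahead presumably account for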
        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def prepare_signals(self, signals, group_size, signal_idx):
        """ Extracts, delays and groups signals.

        Maps a tensor of shape (batch, time, channels) to shape
        (batch, time // group_size, group_size * len(signal_idx)).
        """

        batch_size, sequence_length, num_signals = signals.shape

        # extract signals according to position
        signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
                            dim=-1)

        # delay the pcm by group_size - 1 samples to account for grouping:
        # after reshaping, each group then ends at the group-boundary sample
        signals = torch.roll(signals, group_size - 1, -2)

        # reshape
        signals = torch.reshape(signals,
            (batch_size, sequence_length // group_size, group_size * len(signal_idx)))

        return signals

    def sample_rate_network(self, signals, c, gru_states):

        signals_a        = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
        embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
        # features at GRU A rate
        c_upsampled_a    = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
        # features at GRU B rate
        c_upsampled_b    = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)

        y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # first round of upsampling and subconditioning
        c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
        y = self.subconditioner_a(y, c_signals_a)
        y = interleave_tensors(y)

        y = torch.concat((y, c_upsampled_b), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
        y = self.subconditioner_b(y, c_signals_b)

        y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
        y = interleave_tensors(y)

        return y, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

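        # NOTE: self.dual_fc is a per-substep module list, so the call below
        # would need a substep index; this single-rate helper is not referenced
        # elsewhere in this file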
        y = self.dual_fc(y)

        return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):

        c           = self.frame_rate_network(features, periods)
        y, _        = self.sample_rate_network(signals, c, gru_states)
        log_probs   = torch.log_softmax(y, dim=-1)

        return log_probs

    def generate(self, features, periods, lpcs):

        with torch.no_grad():
            device = next(self.parameters()).device

            num_frames          = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order           = lpcs.shape[-1]
            num_input_signals   = len(self.signals)
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers
            last_signal       = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            prediction        = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            last_error        = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            output            = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))

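            # buffers start at 128, presumably the u-law code for zero (silence)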
            input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
            # conditioning signals for subconditioner a
            c_signals_a   = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
            # conditioning signals for subconditioner b
            c_signals_b   = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)

            # signal dict
            signal_dict = {
                'prediction'    : prediction,
                'last_error'    : last_error,
                'last_signal'   : last_signal
            }

            # push data to device
            features = features.to(device)
            periods  = periods.to(device)
            lpcs     = lpcs.to(device)

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr  = features[frame_index + self.feature_history, pitch_corr_position]
                a           = -torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c   = c[:, frame_index : frame_index + 1, :]

                for i in range(0, self.frame_size, self.group_size_gru_a):
                    pcm_position    = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # calculate newest prediction
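                    # (a holds the negated, time-reversed LPC coefficients, so
                    # this inner product is the LPC prediction
                    # p[n] = -sum_k lpc[k] * x[n - k] over the last lpc_order samples)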
                    prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1 : pcm_position + 1] * a)

                    # prepare input
                    for slot in range(self.group_size_gru_a):
                        k = slot - self.group_size_gru_a + 1
                        for idx, name in enumerate(self.signals):
                            input_signals[idx + slot * num_input_signals] = lin2ulawq(
                                signal_dict[name][pcm_position + k]
                            )

                    # run GRU A
                    embed_signals   = self.signal_embedding(input_signals.reshape((1, 1, -1)))
                    embed_signals   = torch.flatten(embed_signals, 2)
                    y               = torch.cat((embed_signals, current_c), dim=-1)
                    h_a, gru_a_state = self.gru_a(y, gru_a_state)

                    # loop over substeps_a
                    for step_a in range(self.substeps_a):
                        # prepare conditioning input
                        for slot in range(self.group_size_subcondition_a):
                            k = slot - self.group_size_subcondition_a + 1
                            for idx, name in enumerate(self.subcondition_signals_a):
                                # stride by the number of subconditioning signals;
                                # num_input_signals would mis-index c_signals_a
                                # whenever the two signal sets differ in size
                                c_signals_a[idx + slot * len(self.subcondition_signals_a)] = lin2ulawq(
                                    signal_dict[name][pcm_position + k]
                                )

                        # subconditioning
                        h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))

                        # run GRU B
                        y = torch.cat((h_a, current_c), dim=-1)
                        h_b, gru_b_state = self.gru_b(y, gru_b_state)

                        # loop over substeps b
                        for step_b in range(self.substeps_b):
                            # prepare subconditioning input
                            for idx, name in enumerate(self.subcondition_signals_b):
                                c_signals_b[idx] = lin2ulawq(
                                    signal_dict[name][pcm_position]
                                )

                            # subcondition
                            h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))

                            # run dual FC
                            probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)

                            # sample
                            new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))

                            # update signals
                            sig = new_exc + prediction[pcm_position]
                            last_error[pcm_position + 1] = new_exc
                            last_signal[pcm_position + 1] = sig

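                            # one-pole de-emphasis filter; the 0.85 coefficient
                            # presumably matches the pre-emphasis applied during
                            # feature extraction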
                            mem = 0.85 * mem + float(sig)
                            output[output_position] = clip_to_int16(round(mem))

                            # increase positions
                            pcm_position += 1
                            output_position += 1

                            # calculate next prediction
                            prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1 : pcm_position + 1] * a)

        return output
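

# Usage sketch (tensor shapes inferred from generate() above; the config
# contents are an assumption, see the sketch near the imports):
#
#   model = MultiRateLPCNet(config)
#   # features: (T, feature_dimension) float
#   # periods:  (T, 1) long, values in [0, period_levels)
#   # lpcs:     (T, lpc_order) float
#   pcm = model.generate(features, periods, lpcs)  # int16 PCM samples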