1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import torch
from torch import nn
import torch.nn.functional as F

from utils.endoscopy import write_data

from utils.ada_conv import adaconv_kernel
from utils.softquant import soft_quant


class LimitedAdaptiveConv1d(nn.Module):
    COUNTER = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 feature_dim,
                 frame_size=160,
                 overlap_size=40,
                 padding=None,
                 name=None,
                 gain_limits_db=[-6, 6],
                 shape_gain_db=0,
                 norm_p=2,
                 softquant=False,
                 apply_weight_norm=False,
                 **kwargs):
        """

        Parameters:
        -----------

        in_channels : int
            number of input channels

        out_channels : int
            number of output channels

        kernel_size : int
            length of the adaptive filter kernels

        feature_dim : int
            dimension of features from which kernels and gains are computed

        frame_size : int
            frame size in samples

        overlap_size : int
            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame

        padding : List[int, int]
            left and right padding in samples; if None, padding for a centered kernel is used

        name : str
            name used for debug output; a unique default name is generated if None

        gain_limits_db : List[float, float]
            lower and upper limit in dB for the adaptive filter gains

        shape_gain_db : float
            maximal contribution of the adaptive kernel shape in dB (clipped to 0 dB); the normalized
            kernel is blended with an identity kernel as shape_gain * kernel + (1 - shape_gain) * identity

        norm_p : int
            p of the p-norm used for normalizing the adaptive kernels

        softquant : bool
            if true, soft quantization is applied to the kernel-generating layer

        apply_weight_norm : bool
            if true, weight normalization is applied to the linear layers

        """

        super(LimitedAdaptiveConv1d, self).__init__()

        self.in_channels    = in_channels
        self.out_channels   = out_channels
        self.feature_dim    = feature_dim
        self.kernel_size    = kernel_size
        self.frame_size     = frame_size
        self.overlap_size   = overlap_size
        self.gain_limits_db = gain_limits_db
        self.shape_gain_db  = shape_gain_db
        self.norm_p         = norm_p

        if name is None:
            self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER)
            LimitedAdaptiveConv1d.COUNTER += 1
        else:
            self.name = name

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # network for generating convolution weights
        self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))
        if softquant:
            self.conv_kernel = soft_quant(self.conv_kernel)

        # maximal deviation of the adaptive kernel from an identity kernel (clipped to 1, i.e. 0 dB)
        self.shape_gain = min(1, 10**(shape_gain_db / 20))

        # network for generating gains; dB limits are converted to natural log scale
        # (ln(10) / 20 = 0.11512925464970229)
        self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
        log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
        self.filter_gain_a = (log_max - log_min) / 2
        self.filter_gain_b = (log_max + log_min) / 2
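        # Note on the mapping (illustrative, not in the original source): with
        # a = (log_max - log_min) / 2 and b = (log_max + log_min) / 2, the gain
        # exp(a * tanh(u) + b) sweeps (exp(log_min), exp(log_max)) =
        # (10**(gain_limits_db[0] / 20), 10**(gain_limits_db[1] / 20)) as tanh(u)
        # runs over (-1, 1), so gains stay strictly within the configured dB limits.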

        if padding is None:
            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
        else:
            self.padding = padding

        # raised-cosine window fading out over the first overlap_size samples of a frame
        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)


    def flop_count(self, rate):
        frame_rate = rate / self.frame_size
        overlap = self.overlap_size
        overhead = overlap / self.frame_size

        count = 0

        # kernel computation and filtering
        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)

        # gain computation
        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels

        # windowing
        count += 3 * overlap * frame_rate * self.out_channels

        return count

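    # Worked example for flop_count (illustrative, not from the original source):
    # at rate=16000 with the defaults frame_size=160 and overlap_size=40, and
    # assuming in_channels=out_channels=1, kernel_size=16, feature_dim=64, we get
    # frame_rate=100 and overhead=0.25, so
    #   kernel computation: 2 * 100 * 64 * 16                   = 204800
    #   filtering:          2 * 1 * 1 * 16 * 1.25 * 16000       = 640000
    #   gain computation:   2 * 100 * 64 * 1 + 16000 * 1.25 * 1 =  32800
    #   windowing:          3 * 40 * 100 * 1                    =  12000
    # i.e. roughly 0.89 MFLOPS in total.
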
    def forward(self, x, features, debug=False):
        """ adaptive 1d convolution

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, in_channels, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)

        # fade-in window used by adaconv_kernel for cross-fading between per-frame filters
        win1 = torch.flip(self.overlap_win, [0])

        if num_samples // self.frame_size != num_frames:
            raise ValueError('non matching sizes in LimitedAdaptiveConv1d.forward')

        # per-frame kernels of shape (batch_size, num_frames, out_channels, in_channels, kernel_size)
        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))

        # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension)
        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True))

        # limit shape: blend the adaptive kernel with an identity (single-tap) kernel
        id_kernels = torch.zeros_like(conv_kernels)
        id_kernels[..., self.padding[1]] = 1

        conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels

        # calculate gains, limited to the configured dB range by the tanh mapping
        conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)
        if debug and batch_size == 1:
            key = self.name + "_gains"
            write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
            key = self.name + "_kernels"
            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)

        conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)

        # reorder to (batch_size, out_channels, in_channels, num_frames, kernel_size) for adaconv_kernel
        conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)

        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)

        return output
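

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file). It
    # assumes the utils package (adaconv_kernel, soft_quant, write_data) is on
    # the Python path, as in the osce training setup; input shapes follow the
    # docstrings above.
    batch_size, num_frames = 2, 10
    conv = LimitedAdaptiveConv1d(in_channels=1,
                                 out_channels=1,
                                 kernel_size=15,
                                 feature_dim=64,
                                 frame_size=160,
                                 overlap_size=40)

    x = torch.randn(batch_size, conv.in_channels, num_frames * conv.frame_size)
    features = torch.randn(batch_size, num_frames, conv.feature_dim)

    y = conv(x, features)
    print(y.shape)  # expected: (batch_size, out_channels, num_frames * frame_size)
    print(f"complexity at 16 kHz: {conv.flop_count(16000) / 1e6:.2f} MFLOPS")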