1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import torch 31from torch import nn 32import torch.nn.functional as F 33 34from utils.endoscopy import write_data 35 36from utils.ada_conv import adaconv_kernel 37from utils.softquant import soft_quant 38 39class LimitedAdaptiveConv1d(nn.Module): 40 COUNTER = 1 41 42 def __init__(self, 43 in_channels, 44 out_channels, 45 kernel_size, 46 feature_dim, 47 frame_size=160, 48 overlap_size=40, 49 padding=None, 50 name=None, 51 gain_limits_db=[-6, 6], 52 shape_gain_db=0, 53 norm_p=2, 54 softquant=False, 55 apply_weight_norm=False, 56 **kwargs): 57 """ 58 59 Parameters: 60 ----------- 61 62 in_channels : int 63 number of input channels 64 65 out_channels : int 66 number of output channels 67 68 feature_dim : int 69 dimension of features from which kernels, biases and gains are computed 70 71 frame_size : int 72 frame size 73 74 overlap_size : int 75 overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame 76 77 use_bias : bool 78 if true, biases will be added to output channels 79 80 81 padding : List[int, int] 82 83 """ 84 85 super(LimitedAdaptiveConv1d, self).__init__() 86 87 88 89 self.in_channels = in_channels 90 self.out_channels = out_channels 91 self.feature_dim = feature_dim 92 self.kernel_size = kernel_size 93 self.frame_size = frame_size 94 self.overlap_size = overlap_size 95 self.gain_limits_db = gain_limits_db 96 self.shape_gain_db = shape_gain_db 97 self.norm_p = norm_p 98 99 if name is None: 100 self.name = "limited_adaptive_conv1d_" + str(LimitedAdaptiveConv1d.COUNTER) 101 LimitedAdaptiveConv1d.COUNTER += 1 102 else: 103 self.name = name 104 105 norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x 106 107 # network for generating convolution weights 108 self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size)) 109 if softquant: 110 self.conv_kernel = soft_quant(self.conv_kernel) 111 112 self.shape_gain = min(1, 10**(shape_gain_db / 20)) 113 114 self.filter_gain = norm(nn.Linear(feature_dim, out_channels)) 115 log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229 116 self.filter_gain_a = (log_max - log_min) / 2 117 self.filter_gain_b = (log_max + log_min) / 2 118 119 if type(padding) == type(None): 120 self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2] 121 else: 122 self.padding = padding 123 124 self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False) 125 126 127 def flop_count(self, rate): 128 frame_rate = rate / self.frame_size 129 overlap = self.overlap_size 130 overhead = overlap / self.frame_size 131 132 count = 0 133 134 # kernel computation and filtering 135 count += 2 * (frame_rate * self.feature_dim * self.kernel_size) 136 count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate) 137 138 # gain computation 139 140 count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels 141 142 # windowing 143 count += 3 * overlap * frame_rate * self.out_channels 144 145 return count 146 147 def forward(self, x, features, debug=False): 148 """ adaptive 1d convolution 149 150 151 Parameters: 152 ----------- 153 x : torch.tensor 154 input signal of shape (batch_size, in_channels, num_samples) 155 156 feathres : torch.tensor 157 frame-wise features of shape (batch_size, num_frames, feature_dim) 158 159 """ 160 161 batch_size = x.size(0) 162 num_frames = features.size(1) 163 num_samples = x.size(2) 164 frame_size = self.frame_size 165 overlap_size = self.overlap_size 166 kernel_size = self.kernel_size 167 win1 = torch.flip(self.overlap_win, [0]) 168 win2 = self.overlap_win 169 170 if num_samples // self.frame_size != num_frames: 171 raise ValueError('non matching sizes in AdaptiveConv1d.forward') 172 173 conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size)) 174 175 # normalize kernels (TODO: switch to L1 and normalize over kernel and input channel dimension) 176 conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=[-2, -1], keepdim=True)) 177 178 # limit shape 179 id_kernels = torch.zeros_like(conv_kernels) 180 id_kernels[..., self.padding[1]] = 1 181 182 conv_kernels = self.shape_gain * conv_kernels + (1 - self.shape_gain) * id_kernels 183 184 # calculate gains 185 conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b) 186 if debug and batch_size == 1: 187 key = self.name + "_gains" 188 write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size) 189 key = self.name + "_kernels" 190 write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size) 191 192 193 conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1) 194 195 conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4) 196 197 output = adaconv_kernel(x, conv_kernels, win1, fft_size=256) 198 199 200 return output