import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np

which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame look-ahead (used for feature PreCondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        torch.manual_seed(5)

        # Asymmetric padding: (kernel_size - 2) frames of history on the left
        # and exactly one frame of look-ahead on the right.
        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) \
                    or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        x = F.pad(x, (self.padding_left, self.padding_right))
        conv_out = self.conv(x)
        return conv_out

# (modified) GLU activation layer definition
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) \
                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        # Modified GLU: tanh on the signal path instead of the identity.
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
        return out

# GRU layer whose initial hidden state is predicted from the continuation latent
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()
        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # Maps the 64-dim continuation latent to the initial GRU state h0.
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) \
                    or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        self.gru.flatten_parameters()

        h0 = self.cont_fc(x0).unsqueeze(0)

        output, h0 = self.gru(x, h0)

        return self.nl(output)
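# ---------------------------------------------------------------------------
# Note on the framewise convolution below (added explanation, not original
# code): ContFramewiseConv flattens the (B, T, frame_len) input into one long
# vector, left-pads it with a learned continuation segment predicted from the
# 64-dim latent, and uses F.unfold to gather one overlapping window of
# frame_kernel_size frames per output frame, so a single shared Linear acts
# like a causal convolution over frames with stride frame_len. Window
# arithmetic for the default frame_len=256, frame_kernel_size=3:
#
#   x:    (B, T, 256)  ->  x_flat:  (B, 1, 256*T)
#   pad:  (B, 1, 512)  ->  padded:  (B, 1, 1, 512 + 256*T)
#   unfold(kernel_size=(1, 768), stride=256)  ->  (B, 768, T)  ->  (B, T, 768)
# ---------------------------------------------------------------------------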
# Framewise convolution layer definition
class ContFramewiseConv(torch.nn.Module):

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if (causal == True) or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            # Predicts the left padding (the "past" frames) from the 64-dim
            # continuation latent instead of padding with zeros.
            self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.required_pad_left, bias=False)),
                                         nn.Tanh())

        else:
            # Note: this non-causal branch defines no cont_fc, so forward()
            # only supports the default causal=True (or frame_kernel_size 1).
            self.required_pad_left = (self.frame_kernel_size - 1) // 2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1) // 2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim))
        if act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh())

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or \
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        if self.frame_kernel_size == 1:
            return self.fc(x)

        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        x_flat_padded_unfolded = F.unfold(x_flat_padded,
                                          kernel_size=(1, self.fc_input_dim),
                                          stride=self.frame_len).permute(0, 2, 1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out

# A fully-connected based upsampling layer definition
class UpsampleFC(nn.Module):
    def __init__(self, in_ch, out_ch, upsample_factor):
        super(UpsampleFC, self).__init__()
        torch.manual_seed(5)

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(in_ch, out_ch * upsample_factor, bias=False)
        self.nl = nn.Tanh()

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or \
                    isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.permute(0, 2, 1)
        x = self.nl(self.fc(x))
        x = x.reshape((batch_size, -1, self.out_ch))
        x = x.permute(0, 2, 1)
        return x

########################### The complete model definition #################################

class FWGAN400ContLarge(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        self.bfcc_with_corr_upsampler = UpsampleFC(19, 80, 4)

        self.feat_in_conv1 = ConvLookahead(160, 256, kernel_size=5)
        self.feat_in_nl1 = GLU(256)

        # Maps the 321-dim continuation input (log-norm + 320 normalized
        # past samples) down to the 64-dim continuation latent.
        self.cont_net = nn.Sequential(which_norm(nn.Linear(321, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 64, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(64, 64, bias=False)),
                                      nn.Tanh())

        self.rnn = ContForwardGRU(256, 256)

        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 40)
        self.fwc7 = ContFramewiseConv(40, 40)

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or \
                    isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
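        """Build sin/cos phase embeddings from per-frame pitch periods.

        (Docstring added for clarity; shapes are inferred from the code.)
        For each 160-sample frame with period p, the angular frequency is
        f = 2*pi/p. The running offset phase0 accumulates 160*f per frame
        so the sinusoids remain continuous across frame boundaries. Each
        frame's 160 samples are folded into 4 subframes of 40, and the sin
        and cos parts are concatenated, giving (B, 4*n_frames, 80).
        """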
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 40)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 40)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            # Carry the accumulated phase into the next frame so the
            # sinusoids stay continuous across frame boundaries.
            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals

    def gain_multiply(self, x, c0):
        # Convert the first feature channel (the scaled c0 / energy cepstral
        # coefficient) to a linear gain, repeated over each 160-sample frame.
        gain = 10 ** (0.5 * c0 / np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # Normalize the continuation samples and prepend their log-energy.
        norm_x0 = torch.norm(x0, 2, dim=-1, keepdim=True)
        x0 = x0 / torch.sqrt((1e-8) + norm_x0 ** 2)
        x0 = torch.cat((torch.log(norm_x0 + 1e-7), x0), dim=-1)

        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()

        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())

        feat_in = torch.cat((p_embed, envelope), dim=1)

        wav_latent1 = self.feat_in_nl1(self.feat_in_conv1(feat_in).permute(0, 2, 1).contiguous())

        cont_latent = self.cont_net(x0)

        rnn_out = self.rnn(wav_latent1, cont_latent)

        fwc1_out = self.fwc1(rnn_out, cont_latent)
        fwc2_out = self.fwc2(fwc1_out, cont_latent)
        fwc3_out = self.fwc3(fwc2_out, cont_latent)
        fwc4_out = self.fwc4(fwc3_out, cont_latent)
        fwc5_out = self.fwc5(fwc4_out, cont_latent)
        fwc6_out = self.fwc6(fwc5_out, cont_latent)
        fwc7_out = self.fwc7(fwc6_out, cont_latent)

        waveform = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)

        waveform = self.gain_multiply(waveform, bfcc_with_corr[:, :, :1])

        return waveform
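
# ---------------------------------------------------------------------------
# Minimal smoke test (added illustration; the shapes are inferred from the
# layer definitions above, not prescribed by the original): pitch_period
# holds one period value per 160-sample frame, bfcc_with_corr holds 19
# features per frame with the c0 energy coefficient first, and x0 carries
# the previous 320 output samples for continuation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = FWGAN400ContLarge()
    B, T = 1, 5
    pitch_period = torch.randint(32, 256, (B, T)).float()  # periods in samples
    bfcc_with_corr = torch.randn(B, T, 19)
    x0 = torch.randn(B, 320)
    waveform = model(pitch_period, bfcc_with_corr, x0)
    print(waveform.shape)  # expected: torch.Size([1, 800]) = (B, 160 * T)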