1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30import math as m 31import copy 32 33import torch 34import torch.nn.functional as F 35from torch import nn 36from torch.nn.utils import weight_norm, spectral_norm 37import torchaudio 38 39from utils.spec import gen_filterbank 40 41# auxiliary functions 42 43def remove_all_weight_norms(module): 44 for m in module.modules(): 45 if hasattr(m, 'weight_v'): 46 nn.utils.remove_weight_norm(m) 47 48 49def create_smoothing_kernel(h, w, gamma=1.5): 50 51 ch = h / 2 - 0.5 52 cw = w / 2 - 0.5 53 54 sh = gamma * ch 55 sw = gamma * cw 56 57 vx = ((torch.arange(h) - ch) / sh) ** 2 58 vy = ((torch.arange(w) - cw) / sw) ** 2 59 vals = vx.view(-1, 1) + vy.view(1, -1) 60 kernel = torch.exp(- vals) 61 kernel = kernel / kernel.sum() 62 63 return kernel 64 65 66def create_kernel(h, w, sh, sw): 67 # proto kernel gives disjoint partition of 1 68 proto_kernel = torch.ones((sh, sw)) 69 70 # create smoothing kernel eta 71 h_eta, w_eta = h - sh + 1, w - sw + 1 72 assert h_eta > 0 and w_eta > 0 73 eta = create_smoothing_kernel(h_eta, w_eta).view(1, 1, h_eta, w_eta) 74 75 kernel0 = F.pad(proto_kernel, [w_eta - 1, w_eta - 1, h_eta - 1, h_eta - 1]).unsqueeze(0).unsqueeze(0) 76 kernel = F.conv2d(kernel0, eta) 77 78 return kernel 79 80# positional embeddings 81class FrequencyPositionalEmbedding(nn.Module): 82 def __init__(self): 83 84 super().__init__() 85 86 def forward(self, x): 87 88 N = x.size(2) 89 args = torch.arange(0, N, dtype=x.dtype, device=x.device) * torch.pi * 2 / N 90 cos = torch.cos(args).reshape(1, 1, -1, 1) 91 sin = torch.sin(args).reshape(1, 1, -1, 1) 92 zeros = torch.zeros_like(x[:, 0:1, :, :]) 93 94 y = torch.cat((x, zeros + sin, zeros + cos), dim=1) 95 96 return y 97 98 99class PositionalEmbedding2D(nn.Module): 100 def __init__(self, d=5): 101 102 super().__init__() 103 104 self.d = d 105 106 def forward(self, x): 107 108 N = x.size(2) 109 M = x.size(3) 110 111 h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1) 112 w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1) 113 coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1) 114 115 h_sin = 
# spectral discriminator base class
class SpecDiscriminatorBase(nn.Module):
    RECEPTIVE_FIELD_MAX_WIDTH = 10000

    def __init__(self,
                 layers,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7000],
                 noise_gain=1e-3,
                 fmap_start_index=0
                 ):
        super().__init__()

        self.layers = nn.ModuleList(layers)
        self.resolution = resolution
        self.fs = fs
        self.noise_gain = noise_gain
        self.fmap_start_index = fmap_start_index

        if fmap_start_index >= len(layers):
            raise ValueError(f'fmap_start_index {fmap_start_index} is larger than number of layers {len(layers)}')

        # filter bank for noise shaping
        n_fft = resolution[0]

        self.filterbank = nn.Parameter(
            gen_filterbank(n_fft // 2, fs, keep_size=True),
            requires_grad=False
        )

        # roi bins
        f_step = fs / n_fft
        self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
        self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft // 2 + 1)

        self.init_weights()

        # determine receptive field size, offsets and strides
        self.receptive_field_params = None

        hw = 1000
        while True:
            x = torch.zeros((1, hw, hw))
            with torch.no_grad():
                y = self.run_layer_stack(x)[-1]

            pos0 = [y.size(-2) // 2, y.size(-1) // 2]
            pos1 = [t + 1 for t in pos0]

            hs0, ws0 = self._receptive_field((hw, hw), pos0)
            hs1, ws1 = self._receptive_field((hw, hw), pos1)

            h0 = hs0[1] - hs0[0] + 1
            h1 = hs1[1] - hs1[0] + 1
            w0 = ws0[1] - ws0[0] + 1
            w1 = ws1[1] - ws1[0] + 1

            if h0 != h1 or w0 != w1:
                hw = 2 * hw
            else:
                # strides
                sh = hs1[0] - hs0[0]
                sw = ws1[0] - ws0[0]

                if sh == 0 or sw == 0:
                    # degenerate stride estimate: enlarge the probe and retry
                    # instead of looping forever on the same input size
                    hw = 2 * hw
                else:
                    # offsets
                    oh = hs0[0] - sh * pos0[0]
                    ow = ws0[0] - sw * pos0[1]

                    # overlap factor
                    overlap = w0 / sw + h0 / sh

                    self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}

                    break

            if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
                print("warning: exceeded max size while trying to determine receptive field")
                break

    def run_layer_stack(self, spec):

        output = []

        x = spec.unsqueeze(1)

        for layer in self.layers:
            x = layer(x)
            output.append(x)

        return output

    def forward(self, x):
        """ returns list with feature maps and final score at index -1 """

        x = self.spectrogram(x)
        output = self.run_layer_stack(x)

        return output[self.fmap_start_index:]

    def receptive_field(self, output_pos):

        if self.receptive_field_params is not None:
            s, o, h = self.receptive_field_params['height']
            h_min = output_pos[0] * s + o + self.start_bin
            h_max = h_min + h
            h_min = max(h_min, self.start_bin)
            h_max = min(h_max, self.stop_bin)

            s, o, w = self.receptive_field_params['width']
            w_min = output_pos[1] * s + o
            w_max = w_min + w

            return (h_min, h_max), (w_min, w_max)

        else:
            return None, None

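    # Worked example (added for illustration; the numbers are assumed, not from
    # the original file): if neighbouring output pixels map to input patches
    # offset by sh = 2 bins with extent h0 = 13, then receptive_field((i, j))
    # covers spectrogram rows starting at 2*i + oh + start_bin with height 13,
    # so adjacent output pixels see overlapping patches whenever h0 > sh.
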
    def _receptive_field(self, input_dims, output_pos):
        """ determines receptive field empirically via autograd on a random input (slow) """

        x = torch.randn((1,) + input_dims, requires_grad=True)

        # run input through layers
        y = self.run_layer_stack(x)[-1]
        b, c, h, w = y.shape

        if output_pos[0] >= h or output_pos[1] >= w:
            raise ValueError("position out of range")

        mask = torch.zeros((b, c, h, w))
        mask[0, 0, output_pos[0], output_pos[1]] = 1

        (mask * y).sum().backward()

        hs, ws = torch.nonzero(x.grad[0], as_tuple=True)

        h_min, h_max = hs.min().item(), hs.max().item()
        w_min, w_max = ws.min().item(), ws.max().item()

        return [h_min, h_max], [w_min, w_max]

    def init_weights(self):

        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(module.weight.data)

    def spectrogram(self, x):
        n_fft, hop_length, win_length = self.resolution
        x = x.squeeze(1)
        window = torch.hann_window(win_length).to(x.device)

        x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [B, F, T]
        x = torch.abs(x)

        # noise floor following the spectral envelope
        smoothed_x = torch.matmul(self.filterbank, x)
        noise = torch.randn_like(x) * smoothed_x * self.noise_gain
        x = x + noise

        # frequency ROI
        x = x[:, self.start_bin : self.stop_bin + 1, ...]

        return torchaudio.functional.amplitude_to_DB(x, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)

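    # Diagnostic-usage sketch for the helpers below (added for illustration;
    # shapes inferred from the code, the input length is an assumed example):
    #
    #     spec, g = disc.grad_map(torch.randn(1, 1, 16000))
    #     # spec: full dB spectrogram [F, T]; g: |d loss / d spec|, same shape
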
    def grad_map(self, x):
        self.zero_grad()

        n_fft, hop_length, win_length = self.resolution

        window = torch.hann_window(win_length).to(x.device)

        y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)

        specgram.requires_grad = True
        specgram.retain_grad()

        if specgram.grad is not None:
            specgram.grad.zero_()

        y = specgram[:, self.start_bin : self.stop_bin + 1, ...]

        scores = self.run_layer_stack(y)[-1]

        loss = torch.mean((1 - scores) ** 2)
        loss.backward()

        return specgram.data[0], torch.abs(specgram.grad)[0]

    def relevance_map(self, x):

        n_fft, hop_length, win_length = self.resolution
        y = x.view(-1)
        window = torch.hann_window(win_length).to(x.device)

        y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [F, T] (single signal)
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)

        scores = self.forward(x)[-1]

        sh, _, h = self.receptive_field_params['height']
        sw, _, w = self.receptive_field_params['width']
        kernel = create_kernel(h, w, sh, sw).float().to(scores.device)

        with torch.no_grad():
            pad_w = (w + sw - 1) // sw
            pad_h = (h + sh - 1) // sh
            padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
            # CAVE: padding should be derived from offsets
            rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h // 2, w // 2))
            rv = rv[..., pad_h * sh : -pad_h * sh, pad_w * sw : -pad_w * sw]

        relevance = torch.zeros_like(specgram)
        relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv

        return specgram, relevance

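    # LRP note (added for clarity): the method below propagates relevance with
    # the epsilon rule. Per layer, with a = input activations,
    #
    #     z = layer(a) + eps,   s = R_next / z,   R = a * d/da [ sum(z * s) ],
    #
    # so each input receives relevance in proportion to its contribution to z.
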
    def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
        """ layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """

        # ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations

        def newconv2d(layer, g):

            new_layer = nn.Conv2d(layer.in_channels,
                                  layer.out_channels,
                                  layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  dilation=layer.dilation,
                                  groups=layer.groups)

            try:
                new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
            except AttributeError:
                pass

            try:
                new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
            except AttributeError:
                pass

            return new_layer

        bounds = {
            64: [-85.82449722290039, 2.1755014657974243],
            128: [-84.49211349487305, 3.5078893899917607],
            256: [-80.33127822875977, 7.6687201976776125],
            512: [-73.79328079223633, 14.20672025680542],
            1024: [-67.59239501953125, 20.40760498046875],
            2048: [-62.31902580261231, 25.680974197387698],
        }

        nfft = self.resolution[0]
        if low is None: low = bounds[nfft][0]
        if high is None: high = bounds[nfft][1]

        remove_all_weight_norms(self)

        for p in self.parameters():
            if p.grad is not None:
                p.grad.zero_()

        num_layers = len(self.layers)
        X = self.spectrogram(x).detach()

        # forward pass
        A = [X.unsqueeze(1)] + [None] * len(self.layers)

        for i in range(num_layers - 1):
            A[i + 1] = self.layers[i](A[i])

        # initial relevance is last layer without activation
        r = A[-2]
        last_layer_rs = [r]
        layer = self.layers[-1]
        for sublayer in list(layer)[:-1]:
            r = sublayer(r)
            last_layer_rs.append(r)

        mask = torch.zeros_like(r)
        mask.requires_grad_(False)
        if verbose:
            print(r.min(), r.max())
        if label in {'both', 'fake'}:
            mask[r < -threshold] = 1
        if label in {'both', 'real'}:
            mask[r > threshold] = 1
        r = r * mask

        # backward pass
        R = [None] * num_layers + [r]

        for l in range(num_layers - 1, 0, -1):
            A[l] = (A[l]).data.requires_grad_(True)

            layer = nn.Sequential(*(list(self.layers[l])[:-1]))
            z = layer(A[l]) + eps
            s = (R[l + 1] / z).data
            (z * s).sum().backward()
            c = A[l].grad
            R[l] = (A[l] * c).data

        # first layer
        A[0] = (A[0].data).requires_grad_(True)

        Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
        Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)

        if len(list(self.layers)) > 2:
            # unsafe way to check for embedding layer
            embed = list(self.layers[0])[0]
            conv = list(self.layers[0])[1]

            layer = nn.Sequential(embed, conv)
            layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
            layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))

        else:
            layer = list(self.layers[0])[0]
            layerl = newconv2d(layer, lambda p: p.clamp(min=0))
            layerh = newconv2d(layer, lambda p: p.clamp(max=0))

        z = layer(A[0])
        z -= layerl(Xl) + layerh(Xh)
        s = (R[1] / z).data
        (z * s).sum().backward()
        c, cp, cm = A[0].grad, Xl.grad, Xh.grad

        R[0] = (A[0] * c + Xl * cp + Xh * cm)

        return X, R[0].mean(dim=1)

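# Output-convention sketch (added for illustration): each discriminator's
# forward() returns its feature maps with the final score last, so a typical
# call splits as
#
#     outs = disc(x)                      # x: (B, 1, N) waveform
#     fmaps, score = outs[:-1], outs[-1]
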
def create_3x3_conv_plan(num_layers: int,
                         f_stretch: int,
                         f_down: int,
                         t_stretch: int,
                         t_down: int
                         ):
    """ creates a stride, dilation, padding plan for a 2d conv network

    Args:
        num_layers (int): number of layers
        f_stretch (int): log_2 of stretching factor along frequency axis
        f_down (int): log_2 of downsampling factor along frequency axis
        t_stretch (int): log_2 of stretching factor along time axis
        t_down (int): log_2 of downsampling factor along time axis

    Returns:
        list(list(tuple)): list containing entries [(stride_f, stride_t), (dilation_f, dilation_t), (padding_f, padding_t)]
    """

    assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
    assert f_stretch < num_layers and t_stretch < num_layers

    def process_dimension(n_layers, stretch, down):

        stack_layers = n_layers - 1

        stride_layers = min(min(down, stretch), stack_layers)
        dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)
        final_stride = 2 ** (max(down - stride_layers, 0))

        final_dilation = 1
        if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
            final_dilation = 2

        strides, dilations, paddings = [], [], []
        processed_layers = 0
        current_dilation = 1

        for _ in range(stride_layers):
            # increase receptive field and downsample via stride = 2
            strides.append(2)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        if processed_layers < stack_layers:
            strides.append(1)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        for _ in range(dilation_layers):
            # increase receptive field via dilation = 2
            strides.append(1)
            current_dilation *= 2
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        while processed_layers < n_layers - 1:
            # fill up with std layers
            strides.append(1)
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        # final layer
        strides.append(final_stride)
        current_dilation *= final_dilation
        dilations.append(current_dilation)
        paddings.append(current_dilation)
        processed_layers += 1

        assert processed_layers == n_layers

        return strides, dilations, paddings

    t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
    f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)

    plan = []

    for i in range(num_layers):
        plan.append([
            (f_strides[i], t_strides[i]),
            (f_dilations[i], t_dilations[i]),
            (f_paddings[i], t_paddings[i]),
        ])

    return plan

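# Worked example (added for illustration; values follow from the rules above):
#
#     >>> create_3x3_conv_plan(3, f_stretch=1, f_down=1, t_stretch=1, t_down=0)
#     [[(2, 1), (1, 1), (1, 1)],
#      [(1, 1), (1, 2), (1, 2)],
#      [(1, 1), (1, 2), (1, 2)]]
#
# i.e. one stride-2 layer handles the frequency downsampling/stretch, while a
# dilation-2 layer stretches the receptive field along time without downsampling.
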
class DiscriminatorExperimental(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding = (1, 1)
        in_channels = 1 + 2
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight.add_(bias_val)


configs = {
    'f_down': {
        'stretch': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'ft_down': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'dilated': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (0, 0),
            256: (0, 0),
            512: (0, 0),
            1024: (0, 0),
            2048: (0, 0)
        }
    },
    'mixed': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
}


class DiscriminatorMagFree(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=256,
                 num_layers=5,
                 use_spectral_norm=False,
                 design=None):

        if design is None:
            raise ValueError('design is required for DiscriminatorMagFree')

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        stretch = configs[design]['stretch'][resolution[0]]
        down = configs[design]['down'][resolution[0]]

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.stretch = stretch
        self.down = down

        layers = []
        plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
        in_channels = 1 + 2
        out_channels = self.num_channels
        for i in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            # product over strides
            channel_factor = plan[i][0][0] * plan[i][0][1]
            out_channels = min(channel_factor * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight.add_(bias_val)

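# Instantiation sketch (added for illustration; the argument values are assumed
# examples): resolution is (n_fft, hop_length, win_length), and `design` selects
# one of the `configs` entries above, keyed by n_fft, e.g.
#
#     disc = DiscriminatorMagFree([256, 64, 256], design='dilated')
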
class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding = (1, 1)
        in_channels = 1 + 2
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)


class DiscriminatorMag2dPositional(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 d=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.resolution = resolution
        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.d = d
        embedding_dim = 4 * d

        layers = []
        stride = (2, 2)
        padding = (1, 1)
        in_channels = 1 + embedding_dim
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    PositionalEmbedding2D(d),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + embedding_dim
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                PositionalEmbedding2D(d),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)


class DiscriminatorMag(SpecDiscriminatorBase):
    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=32,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_layers = num_layers

        layers = []
        stride = (1, 1)
        padding = (1, 1)
        in_channels = 1
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels

        layers.append(norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)))

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)


discriminators = {
    'mag': DiscriminatorMag,
    'freqpos': DiscriminatorMagFreqPosition,
    '2dpos': DiscriminatorMag2dPositional,
    'experimental': DiscriminatorExperimental,
    'free': DiscriminatorMagFree
}


class TFDMultiResolutionDiscriminator(torch.nn.Module):
    def __init__(self,
                 fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
                 architecture='mag',
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 use_spectral_norm=False,
                 **kwargs):

        super().__init__()

        # scale FFT sizes from their 16 kHz reference to the actual sampling rate
        fft_sizes = [int(round(fft_size_16k * fs / 16000)) for fft_size_16k in fft_sizes_16k]

        # resolution: (n_fft, hop_length, win_length)
        resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes]

        Disc = discriminators[architecture]

        discs = [Disc(resolutions[i], fs=fs, freq_roi=freq_roi, noise_gain=noise_gain, use_spectral_norm=use_spectral_norm, **kwargs) for i in range(len(resolutions))]

        self.discriminators = nn.ModuleList(discs)

    def forward(self, y):
        outputs = []

        for disc in self.discriminators:
            outputs.append(disc(y))

        return outputs

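# Usage sketch (added for illustration; shapes assumed): the wrapper below
# adapts the multi-resolution discriminator to the (real, fake) interface used
# by GAN training loops,
#
#     mrd = TFDMultiResolutionDiscriminator(architecture='mag')
#     disc = FWGAN_disc_wrapper(mrd)
#     y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(y, y_hat)   # y, y_hat: (B, 1, N)
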
class FWGAN_disc_wrapper(nn.Module):
    def __init__(self, disc):
        super().__init__()

        self.disc = disc

    def forward(self, y, y_hat):

        out_real = self.disc(y)
        out_fake = self.disc(y_hat)

        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for y_real, y_fake in zip(out_real, out_fake):
            y_d_rs.append(y_real[-1])
            y_d_gs.append(y_fake[-1])
            fmap_rs.append(y_real[:-1])
            fmap_gs.append(y_fake[:-1])

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
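
if __name__ == "__main__":
    # Minimal smoke test (added; not part of the original file). The resolution
    # and input length are arbitrary example values; construction also exercises
    # the receptive-field probing in SpecDiscriminatorBase.__init__.
    disc = DiscriminatorMag([256, 64, 256])
    outs = disc(torch.randn(1, 1, 16000))
    print("feature maps:", len(outs) - 1, "score shape:", tuple(outs[-1].shape))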