1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import math as m

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm, spectral_norm
import torchaudio

from utils.spec import gen_filterbank

# auxiliary functions

def remove_all_weight_norms(module):
    """removes weight norm from all submodules that carry a weight_v attribute"""
    for m in module.modules():
        if hasattr(m, 'weight_v'):
            nn.utils.remove_weight_norm(m)


def create_smoothing_kernel(h, w, gamma=1.5):
    """creates a normalized, Gaussian-like h x w smoothing kernel"""

    ch = h / 2 - 0.5
    cw = w / 2 - 0.5

    sh = gamma * ch
    sw = gamma * cw

    vx = ((torch.arange(h) - ch) / sh) ** 2
    vy = ((torch.arange(w) - cw) / sw) ** 2
    vals = vx.view(-1, 1) + vy.view(1, -1)
    kernel = torch.exp(-vals)
    kernel = kernel / kernel.sum()

    return kernel
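
# sanity-check sketch (illustrative, not part of the original code):
#   k = create_smoothing_kernel(3, 3)
# k peaks at the center and k.sum() equals 1 up to float rounding, so it acts
# as a soft local average inside create_kernel below.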


def create_kernel(h, w, sh, sw):
    # proto kernel gives disjoint partition of 1
    proto_kernel = torch.ones((sh, sw))

    # create smoothing kernel eta
    h_eta, w_eta = h - sh + 1, w - sw + 1
    assert h_eta > 0 and w_eta > 0
    eta = create_smoothing_kernel(h_eta, w_eta).view(1, 1, h_eta, w_eta)

    kernel0 = F.pad(proto_kernel, [w_eta - 1, w_eta - 1, h_eta - 1, h_eta - 1]).unsqueeze(0).unsqueeze(0)
    kernel = F.conv2d(kernel0, eta)

    return kernel
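
# usage sketch (illustrative): create_kernel(h, w, sh, sw) yields a kernel of
# shape (1, 1, h, w); relevance_map below feeds it to F.conv_transpose2d with
# stride (sh, sw) to spread each patch score smoothly over its receptive field.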

# positional embeddings
class FrequencyPositionalEmbedding(nn.Module):
    """appends sine and cosine of the normalized frequency-bin index as extra channels"""

    def forward(self, x):

        N = x.size(2)
        args = torch.arange(0, N, dtype=x.dtype, device=x.device) * torch.pi * 2 / N
        cos = torch.cos(args).reshape(1, 1, -1, 1)
        sin = torch.sin(args).reshape(1, 1, -1, 1)
        zeros = torch.zeros_like(x[:, 0:1, :, :])

        y = torch.cat((x, zeros + sin, zeros + cos), dim=1)

        return y
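
# shape sketch (illustrative): for input of shape (B, C, F, T) the module
# returns (B, C + 2, F, T); the discriminators below rely on this when setting
# in_channels = out_channels + 2.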


class PositionalEmbedding2D(nn.Module):
    """appends 2d sinusoidal positional embeddings over frequency and time as extra channels"""

    def __init__(self, d=5):

        super().__init__()

        self.d = d

    def forward(self, x):

        N = x.size(2)
        M = x.size(3)

        h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1)
        w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
        coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1)

        h_sin = torch.sin(coeffs * h_args)
        h_cos = torch.cos(coeffs * h_args)  # was torch.sin, which duplicated h_sin
        w_sin = torch.sin(coeffs * w_args)
        w_cos = torch.cos(coeffs * w_args)  # was torch.sin, which duplicated w_sin

        zeros = torch.zeros_like(x[:, 0:1, :, :])

        y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1)

        return y
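
# shape sketch (illustrative): each of the four embeddings has d channels, so
# the module maps (B, C, F, T) to (B, C + 4 * d, F, T), matching
# embedding_dim = 4 * d in DiscriminatorMag2dPositional below.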


# spectral discriminator base class
class SpecDiscriminatorBase(nn.Module):
    RECEPTIVE_FIELD_MAX_WIDTH = 10000

    def __init__(self,
                 layers,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7000],
                 noise_gain=1e-3,
                 fmap_start_index=0
                 ):
        super().__init__()

        self.layers = nn.ModuleList(layers)
        self.resolution = resolution
        self.fs = fs
        self.noise_gain = noise_gain
        self.fmap_start_index = fmap_start_index

        if fmap_start_index >= len(layers):
            raise ValueError('fmap_start_index is larger than number of layers')

        # filter bank for noise shaping
        n_fft = resolution[0]

        self.filterbank = nn.Parameter(
            gen_filterbank(n_fft // 2, fs, keep_size=True),
            requires_grad=False
        )

        # roi bins
        f_step = fs / n_fft
        self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01))
        self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft // 2 + 1)

        self.init_weights()

        # determine receptive field size, offsets and strides
        self.receptive_field_params = None

        hw = 1000
        while True:
            x = torch.zeros((1, hw, hw))
            with torch.no_grad():
                y = self.run_layer_stack(x)[-1]

            pos0 = [y.size(-2) // 2, y.size(-1) // 2]
            pos1 = [t + 1 for t in pos0]

            hs0, ws0 = self._receptive_field((hw, hw), pos0)
            hs1, ws1 = self._receptive_field((hw, hw), pos1)

            h0 = hs0[1] - hs0[0] + 1
            h1 = hs1[1] - hs1[0] + 1
            w0 = ws0[1] - ws0[0] + 1
            w1 = ws1[1] - ws1[0] + 1

            if h0 != h1 or w0 != w1:
                hw = 2 * hw
            else:

                # strides
                sh = hs1[0] - hs0[0]
                sw = ws1[0] - ws0[0]

                if sh == 0 or sw == 0:
                    # degenerate stride estimate; retry with a larger probe
                    # (a bare continue here would loop forever)
                    hw = 2 * hw
                    continue

                # offsets
                oh = hs0[0] - sh * pos0[0]
                ow = ws0[0] - sw * pos0[1]

                # overlap factor
                overlap = w0 / sw + h0 / sh

                self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap}

                break

            if hw > self.RECEPTIVE_FIELD_MAX_WIDTH:
                print("warning: exceeded max size while trying to determine receptive field")
                # give up; receptive_field_params stays None
                break

        # create transposed convolutional kernel
        #self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False)

    def run_layer_stack(self, spec):

        output = []

        x = spec.unsqueeze(1)

        for layer in self.layers:
            x = layer(x)
            output.append(x)

        return output

    def forward(self, x):
        """ returns array with feature maps and final score at index -1 """

        x = self.spectrogram(x)

        output = self.run_layer_stack(x)

        return output[self.fmap_start_index:]

    def receptive_field(self, output_pos):

        if self.receptive_field_params is not None:
            s, o, h = self.receptive_field_params['height']
            h_min = output_pos[0] * s + o + self.start_bin
            h_max = h_min + h
            h_min = max(h_min, self.start_bin)
            h_max = min(h_max, self.stop_bin)

            s, o, w = self.receptive_field_params['width']
            w_min = output_pos[1] * s + o
            w_max = w_min + w

            return (h_min, h_max), (w_min, w_max)

        else:
            return None, None

    def _receptive_field(self, input_dims, output_pos):
        """ determines receptive field probabilistically via autograd (slow) """

        x = torch.randn((1,) + input_dims, requires_grad=True)

        # run input through layers
        y = self.run_layer_stack(x)[-1]
        b, c, h, w = y.shape

        if output_pos[0] >= h or output_pos[1] >= w:
            raise ValueError("position out of range")

        mask = torch.zeros((b, c, h, w))
        mask[0, 0, output_pos[0], output_pos[1]] = 1

        (mask * y).sum().backward()

        hs, ws = torch.nonzero(x.grad[0], as_tuple=True)

        h_min, h_max = hs.min().item(), hs.max().item()
        w_min, w_max = ws.min().item(), ws.max().item()

        return [h_min, h_max], [w_min, w_max]

    def init_weights(self):

        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(module.weight.data)

    def spectrogram(self, x):
        n_fft, hop_length, win_length = self.resolution
        x = x.squeeze(1)
        window = torch.hann_window(win_length).to(x.device)

        x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [B, F, T]
        x = torch.abs(x)

        # noise floor following spectral envelope
        smoothed_x = torch.matmul(self.filterbank, x)
        noise = torch.randn_like(x) * smoothed_x * self.noise_gain
        x = x + noise

        # frequency ROI
        x = x[:, self.start_bin : self.stop_bin + 1, ...]

        return torchaudio.functional.amplitude_to_DB(x, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)
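
    # usage sketch (illustrative): with resolution = (512, 128, 512) and
    # fs = 16000, spectrogram() returns a dB-scale magnitude tensor of shape
    # (B, stop_bin - start_bin + 1, T), plus a shaped noise floor when
    # noise_gain > 0.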

    def grad_map(self, x):
        self.zero_grad()

        n_fft, hop_length, win_length = self.resolution

        window = torch.hann_window(win_length).to(x.device)

        y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [B, F, T]
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)

        specgram.requires_grad = True
        specgram.retain_grad()

        if specgram.grad is not None:
            specgram.grad.zero_()

        y = specgram[:, self.start_bin : self.stop_bin + 1, ...]

        scores = self.run_layer_stack(y)[-1]

        loss = torch.mean((1 - scores) ** 2)
        loss.backward()

        return specgram.data[0], torch.abs(specgram.grad)[0]

    def relevance_map(self, x):

        n_fft, hop_length, win_length = self.resolution
        y = x.view(-1)
        window = torch.hann_window(win_length).to(x.device)

        y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                       window=window, return_complex=True)  # [F, T] (1d input)
        y = torch.abs(y)

        specgram = torchaudio.functional.amplitude_to_DB(y, db_multiplier=0.0, multiplier=20, amin=1e-05, top_db=80)

        scores = self.forward(x)[-1]

        sh, _, h = self.receptive_field_params['height']
        sw, _, w = self.receptive_field_params['width']
        kernel = create_kernel(h, w, sh, sw).float().to(scores.device)
        with torch.no_grad():
            pad_w = (w + sw - 1) // sw
            pad_h = (h + sh - 1) // sh
            padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
            # CAVE: padding should be derived from offsets
            rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h // 2, w // 2))
            rv = rv[..., pad_h * sh : -pad_h * sh, pad_w * sw : -pad_w * sw]

            relevance = torch.zeros_like(specgram)
            relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv

        return specgram, relevance
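
    # usage sketch (illustrative, assumes a trained discriminator `disc` and a
    # waveform tensor `x` of shape (1, 1, N)):
    #   spec, rel = disc.relevance_map(x)
    # High values in `rel` mark time-frequency regions whose patch scores were
    # decisive for the real/fake decision.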

    def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False):
        """ layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """

        # ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations

        def newconv2d(layer, g):

            new_layer = nn.Conv2d(layer.in_channels,
                                  layer.out_channels,
                                  layer.kernel_size,
                                  stride=layer.stride,
                                  padding=layer.padding,
                                  dilation=layer.dilation,
                                  groups=layer.groups)

            try: new_layer.weight = nn.Parameter(g(layer.weight.data.clone()))
            except AttributeError: pass

            try: new_layer.bias = nn.Parameter(g(layer.bias.data.clone()))
            except AttributeError: pass

            return new_layer

        # per-FFT-size dB bounds used for the z^B rule at the input layer
        bounds = {
            64: [-85.82449722290039, 2.1755014657974243],
            128: [-84.49211349487305, 3.5078893899917607],
            256: [-80.33127822875977, 7.6687201976776125],
            512: [-73.79328079223633, 14.20672025680542],
            1024: [-67.59239501953125, 20.40760498046875],
            2048: [-62.31902580261231, 25.680974197387698],
        }

        nfft = self.resolution[0]
        if low is None: low = bounds[nfft][0]
        if high is None: high = bounds[nfft][1]

        remove_all_weight_norms(self)

        for p in self.parameters():
            if p.grad is not None:
                p.grad.zero_()

        num_layers = len(self.layers)
        X = self.spectrogram(x).detach()

        # forward pass
        A = [X.unsqueeze(1)] + [None] * len(self.layers)

        for i in range(num_layers - 1):
            A[i + 1] = self.layers[i](A[i])

        # initial relevance is last layer without activation
        r = A[-2]
        last_layer_rs = [r]
        layer = self.layers[-1]
        for sublayer in list(layer)[:-1]:
            r = sublayer(r)
            last_layer_rs.append(r)

        mask = torch.zeros_like(r)
        mask.requires_grad_(False)
        if verbose:
            print(r.min(), r.max())
        if label in {'both', 'fake'}:
            mask[r < -threshold] = 1
        if label in {'both', 'real'}:
            mask[r > threshold] = 1
        r = r * mask

        # backward pass
        R = [None] * num_layers + [r]

        for l in reversed(range(1, num_layers)):
            A[l] = (A[l]).data.requires_grad_(True)

            layer = nn.Sequential(*(list(self.layers[l])[:-1]))
            z = layer(A[l]) + eps
            s = (R[l + 1] / z).data
            (z * s).sum().backward()
            c = A[l].grad
            R[l] = (A[l] * c).data

        # first layer: z^B rule with lower/upper bounds low and high
        A[0] = (A[0].data).requires_grad_(True)

        Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True)
        Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True)

        if len(list(self.layers)) > 2:
            # unsafe way to check for embedding layer
            embed = list(self.layers[0])[0]
            conv = list(self.layers[0])[1]

            layer = nn.Sequential(embed, conv)
            layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0)))
            layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0)))

        else:
            layer = list(self.layers[0])[0]
            layerl = newconv2d(layer, lambda p: p.clamp(min=0))
            layerh = newconv2d(layer, lambda p: p.clamp(max=0))

        z = layer(A[0])
        z -= layerl(Xl) + layerh(Xh)
        s = (R[1] / z).data
        (z * s).sum().backward()
        c, cp, cm = A[0].grad, Xl.grad, Xh.grad

        R[0] = (A[0] * c + Xl * cp + Xh * cm)

        return X, R[0].mean(dim=1)
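
# usage sketch (illustrative): X, R = disc.lrp(x, label='fake') returns the
# input spectrogram X and a relevance map R of matching time-frequency shape,
# obtained by propagating the thresholded output scores back to the input.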


def create_3x3_conv_plan(num_layers: int,
                         f_stretch: int,
                         f_down: int,
                         t_stretch: int,
                         t_down: int
                         ):
    """ creates a stride, dilation, padding plan for a 2d conv network

    Args:
        num_layers (int): number of layers
        f_stretch (int): log_2 of stretching factor along frequency axis
        f_down (int): log_2 of downsampling factor along frequency axis
        t_stretch (int): log_2 of stretching factor along time axis
        t_down (int): log_2 of downsampling factor along time axis

    Returns:
        list(list(tuple)): list containing entries [(stride_f, stride_t), (dilation_f, dilation_t), (padding_f, padding_t)]
    """

    assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0
    assert f_stretch < num_layers and t_stretch < num_layers

    def process_dimension(n_layers, stretch, down):

        stack_layers = n_layers - 1

        stride_layers = min(min(down, stretch), stack_layers)
        dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0)
        final_stride = 2 ** (max(down - stride_layers, 0))

        final_dilation = 1
        if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0:
            final_dilation = 2

        strides, dilations, paddings = [], [], []
        processed_layers = 0
        current_dilation = 1

        for _ in range(stride_layers):
            # increase receptive field and downsample via stride = 2
            strides.append(2)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        if processed_layers < stack_layers:
            strides.append(1)
            dilations.append(1)
            paddings.append(1)
            processed_layers += 1

        for _ in range(dilation_layers):
            # increase receptive field via dilation = 2
            strides.append(1)
            current_dilation *= 2
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        while processed_layers < n_layers - 1:
            # fill up with std layers
            strides.append(1)
            dilations.append(current_dilation)
            paddings.append(current_dilation)
            processed_layers += 1

        # final layer
        strides.append(final_stride)
        current_dilation *= final_dilation  # was a no-op expression statement
        dilations.append(current_dilation)
        paddings.append(current_dilation)
        processed_layers += 1

        assert processed_layers == n_layers

        return strides, dilations, paddings

    t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
    f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)

    plan = []

    for i in range(num_layers):
        plan.append([
            (f_strides[i], t_strides[i]),
            (f_dilations[i], t_dilations[i]),
            (f_paddings[i], t_paddings[i]),
        ])

    return plan
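
# worked example (illustrative): create_3x3_conv_plan(3, 1, 1, 0, 0) yields
# [[(2, 1), (1, 1), (1, 1)], [(1, 1), (1, 1), (1, 1)], [(1, 1), (1, 1), (1, 1)]],
# i.e. one stride-2 layer along frequency followed by plain 3x3 layers.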


class DiscriminatorExperimental(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding = (1, 1)
        in_channels = 1 + 2
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight += bias_val  # in-place; plain assignment was a no-op


# per-design (stretch, down) log2 factors for the frequency and time axes,
# keyed by FFT size; consumed by DiscriminatorMagFree below
configs = {
    'f_down': {
        'stretch': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'ft_down': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
    'dilated': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (0, 0),
            256: (0, 0),
            512: (0, 0),
            1024: (0, 0),
            2048: (0, 0)
        }
    },
    'mixed': {
        'stretch': {
            64: (0, 4),
            128: (1, 3),
            256: (2, 2),
            512: (3, 1),
            1024: (4, 0),
            2048: (5, 0)
        },
        'down': {
            64: (0, 0),
            128: (1, 0),
            256: (2, 0),
            512: (3, 0),
            1024: (4, 0),
            2048: (5, 0)
        }
    },
}


class DiscriminatorMagFree(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=256,
                 num_layers=5,
                 use_spectral_norm=False,
                 design=None):

        if design is None:
            raise ValueError('design is required in DiscriminatorMagFree')

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        stretch = configs[design]['stretch'][resolution[0]]
        down = configs[design]['down'][resolution[0]]

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.stretch = stretch
        self.down = down

        layers = []
        plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
        in_channels = 1 + 2
        out_channels = self.num_channels
        for i in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
                    nn.ReLU(inplace=True)
                )
            )
            in_channels = out_channels + 2
            # product over strides
            channel_factor = plan[i][0][0] * plan[i][0][1]
            out_channels = min(channel_factor * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
                nn.Sigmoid()
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

        # bias biases
        bias_val = 0.1
        with torch.no_grad():
            for name, weight in self.named_parameters():
                if 'bias' in name:
                    weight += bias_val  # in-place; plain assignment was a no-op


class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers

        layers = []
        stride = (2, 1)
        padding = (1, 1)
        in_channels = 1 + 2
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    FrequencyPositionalEmbedding(),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + 2
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                FrequencyPositionalEmbedding(),
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)


class DiscriminatorMag2dPositional(SpecDiscriminatorBase):

    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=16,
                 max_channels=512,
                 num_layers=5,
                 d=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_channels_max = max_channels
        self.num_layers = num_layers
        self.d = d
        embedding_dim = 4 * d

        layers = []
        stride = (2, 2)
        padding = (1, 1)
        in_channels = 1 + embedding_dim
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    PositionalEmbedding2D(d),
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels + embedding_dim
            out_channels = min(2 * out_channels, self.num_channels_max)

        layers.append(
            nn.Sequential(
                PositionalEmbedding2D(d),  # pass d explicitly so channel counts match for d != 5
                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
            )
        )

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)


class DiscriminatorMag(SpecDiscriminatorBase):
    def __init__(self,
                 resolution,
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 num_channels=32,
                 num_layers=5,
                 use_spectral_norm=False):

        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.num_channels = num_channels
        self.num_layers = num_layers

        layers = []
        stride = (1, 1)
        padding = (1, 1)
        in_channels = 1
        out_channels = self.num_channels
        for _ in range(self.num_layers):
            layers.append(
                nn.Sequential(
                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            in_channels = out_channels

        layers.append(norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)))

        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)

discriminators = {
    'mag': DiscriminatorMag,
    'freqpos': DiscriminatorMagFreqPosition,
    '2dpos': DiscriminatorMag2dPositional,
    'experimental': DiscriminatorExperimental,
    'free': DiscriminatorMagFree
}
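
# usage sketch (illustrative): a single-resolution discriminator can be built as
#   disc = discriminators['mag']((512, 128, 512), fs=16000)
# where the first argument is the (n_fft, hop_length, win_length) resolution.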


class TFDMultiResolutionDiscriminator(torch.nn.Module):
    def __init__(self,
                 fft_sizes_16k=[64, 128, 256, 512, 1024, 2048],
                 architecture='mag',
                 fs=16000,
                 freq_roi=[50, 7400],
                 noise_gain=0,
                 use_spectral_norm=False,
                 **kwargs):

        super().__init__()

        # scale the 16 kHz reference FFT sizes to the target sampling rate
        fft_sizes = [int(round(fft_size_16k * fs / 16000)) for fft_size_16k in fft_sizes_16k]

        resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes]

        Disc = discriminators[architecture]

        discs = [Disc(resolutions[i], fs=fs, freq_roi=freq_roi, noise_gain=noise_gain, use_spectral_norm=use_spectral_norm, **kwargs) for i in range(len(resolutions))]

        self.discriminators = nn.ModuleList(discs)

    def forward(self, y):
        outputs = []

        for disc in self.discriminators:
            outputs.append(disc(y))

        return outputs
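
# usage sketch (illustrative): for a waveform batch x of shape (B, 1, N),
#   mrd = TFDMultiResolutionDiscriminator(architecture='mag')
#   outputs = mrd(x)  # one list per resolution; feature maps, final score last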


class FWGAN_disc_wrapper(nn.Module):
    def __init__(self, disc):
        super().__init__()

        self.disc = disc

    def forward(self, y, y_hat):

        out_real = self.disc(y)
        out_fake = self.disc(y_hat)

        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for y_real, y_fake in zip(out_real, out_fake):
            y_d_rs.append(y_real[-1])
            y_d_gs.append(y_fake[-1])
            fmap_rs.append(y_real[:-1])
            fmap_gs.append(y_fake[:-1])

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
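
# usage sketch (illustrative): the wrapper adapts the multi-resolution output to
# the (scores_real, scores_fake, fmaps_real, fmaps_fake) convention used by
# FWGAN-style training loops:
#   disc = FWGAN_disc_wrapper(TFDMultiResolutionDiscriminator())
#   y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(y, y_hat)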
975