# /aosp_15_r20/external/libopus/dnn/torch/fwgan/models/fwgan400.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import numpy as np

which_norm = weight_norm

#################### Definition of basic model components ####################

# Convolutional layer with 1 frame look-ahead (used for the feature PreCondNet)
class ConvLookahead(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
        super(ConvLookahead, self).__init__()
        # Fixed seed so the orthogonal initialization below is reproducible
        torch.manual_seed(5)

        # Asymmetric padding: (kernel_size - 2) dilated samples of history on
        # the left, one dilated sample of look-ahead on the right.
        self.padding_left = (kernel_size - 2) * dilation
        self.padding_right = 1 * dilation

        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        x = F.pad(x, (self.padding_left, self.padding_right))
        return self.conv(x)

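# Hedged usage sketch (not part of the original file): with kernel_size=5 and
# dilation=1, the padding above totals kernel_size - 1 samples, so the time
# axis is preserved while the layer sees one future sample.
def _demo_conv_lookahead():
    conv = ConvLookahead(in_ch=160, out_ch=256, kernel_size=5)
    x = torch.randn(2, 160, 40)           # (batch, channels, time)
    y = conv(x)
    assert y.shape == (2, 256, 40)        # same length in, same length out
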
# (Modified) GLU activation layer: tanh of the input gated by a learned sigmoid gate
class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()
        torch.manual_seed(5)

        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        # Unlike the standard GLU, the input is passed through tanh instead of
        # being split into value and gate halves.
        out = torch.tanh(x) * torch.sigmoid(self.gate(x))

        return out

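# Hedged usage sketch (not part of the original file): the gate is elementwise
# over the last (feature) dimension, so output shape equals input shape.
def _demo_glu():
    glu = GLU(feat_size=256)
    x = torch.randn(2, 40, 256)           # (batch, frames, features)
    assert glu(x).shape == x.shape
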
# GRU layer whose initial hidden state is derived from a continuation embedding
class ContForwardGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(ContForwardGRU, self).__init__()
        torch.manual_seed(5)

        self.hidden_size = hidden_size

        # Maps the 64-dim continuation embedding to the initial hidden state
        self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.hidden_size, bias=False)),
                                     nn.Tanh())

        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, bias=False)

        self.nl = GLU(self.hidden_size)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        self.gru.flatten_parameters()

        h0 = self.cont_fc(x0).unsqueeze(0)
        output, _ = self.gru(x, h0)

        return self.nl(output)

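# Hedged usage sketch (not part of the original file): x0 is the 64-dim
# continuation embedding (here random) that seeds the GRU's hidden state.
def _demo_cont_forward_gru():
    rnn = ContForwardGRU(input_size=256, hidden_size=256)
    x = torch.randn(2, 40, 256)           # (batch, frames, features)
    x0 = torch.randn(2, 64)               # continuation embedding
    assert rnn(x, x0).shape == (2, 40, 256)
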
# Framewise convolution layer: a linear layer applied to a sliding window of
# frame_kernel_size consecutive frames, with the left context predicted from a
# continuation embedding instead of zero padding.
class ContFramewiseConv(torch.nn.Module):

    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
        super(ContFramewiseConv, self).__init__()
        torch.manual_seed(5)

        self.frame_kernel_size = frame_kernel_size
        self.frame_len = frame_len

        if causal or (self.frame_kernel_size == 2):
            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
            self.required_pad_right = 0

            self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.required_pad_left, bias=False)),
                                         nn.Tanh()
                                        )
        else:
            # Note: this branch defines no cont_fc, so forward() only supports
            # the causal configuration (the only one used in this model).
            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len

        self.fc_input_dim = self.frame_kernel_size * self.frame_len
        self.fc_out_dim = out_dim

        if act == 'glu':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    GLU(self.fc_out_dim)
                                   )
        elif act == 'tanh':
            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
                                    nn.Tanh()
                                   )

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, x0):
        if self.frame_kernel_size == 1:
            return self.fc(x)

        # Flatten the frames, prepend the predicted left context, then unfold
        # into overlapping windows of frame_kernel_size frames each.
        x_flat = x.reshape(x.size(0), 1, -1)
        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)

        x_flat_padded_unfolded = F.unfold(x_flat_padded,
                    kernel_size=(1, self.fc_input_dim), stride=self.frame_len).permute(0, 2, 1).contiguous()

        out = self.fc(x_flat_padded_unfolded)
        return out

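# Hedged usage sketch (not part of the original file): with the default causal
# kernel of 3 frames, each output frame sees the current frame plus two frames
# of left context, the first of which comes from cont_fc(x0).
def _demo_cont_framewise_conv():
    fwc = ContFramewiseConv(frame_len=256, out_dim=128)
    x = torch.randn(2, 40, 256)           # (batch, frames, frame_len)
    x0 = torch.randn(2, 64)               # continuation embedding
    assert fwc(x, x0).shape == (2, 40, 128)
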
# A fully-connected upsampling layer: each input frame is expanded into
# upsample_factor output frames by a single linear projection.
class UpsampleFC(nn.Module):
    def __init__(self, in_ch, out_ch, upsample_factor):
        super(UpsampleFC, self).__init__()
        torch.manual_seed(5)

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(in_ch, out_ch * upsample_factor, bias=False)
        self.nl = nn.Tanh()

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.permute(0, 2, 1)                        # (batch, time, in_ch)
        x = self.nl(self.fc(x))                       # (batch, time, out_ch * factor)
        x = x.reshape((batch_size, -1, self.out_ch))  # (batch, time * factor, out_ch)
        x = x.permute(0, 2, 1)                        # (batch, out_ch, time * factor)
        return x

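# Hedged usage sketch (not part of the original file): 19 feature channels at
# 10 frames become 80 channels at 40 frames, matching the model's use below.
def _demo_upsample_fc():
    up = UpsampleFC(in_ch=19, out_ch=80, upsample_factor=4)
    x = torch.randn(2, 19, 10)            # (batch, channels, frames)
    assert up(x).shape == (2, 80, 40)
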
########################### The complete model definition #################################

class FWGAN400ContLarge(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)

        # Upsamples the 19 conditioning features (BFCC + correlation) 4x in
        # time to 80 channels, matching the 80-channel phase embedding
        self.bfcc_with_corr_upsampler = UpsampleFC(19, 80, 4)

        self.feat_in_conv1 = ConvLookahead(160, 256, kernel_size=5)
        self.feat_in_nl1 = GLU(256)

        # Maps the 321-dim continuation input (log-norm + 320 normalized
        # samples, see forward()) down to a 64-dim continuation embedding
        self.cont_net = nn.Sequential(which_norm(nn.Linear(321, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 160, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(160, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 80, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(80, 64, bias=False)),
                                      nn.Tanh(),
                                      which_norm(nn.Linear(64, 64, bias=False)),
                                      nn.Tanh())

        self.rnn = ContForwardGRU(256, 256)

        # Stack of framewise convolutions narrowing the per-subframe latent
        # from 256 down to the 40 output samples per subframe
        self.fwc1 = ContFramewiseConv(256, 256)
        self.fwc2 = ContFramewiseConv(256, 128)
        self.fwc3 = ContFramewiseConv(128, 128)
        self.fwc4 = ContFramewiseConv(128, 64)
        self.fwc5 = ContFramewiseConv(64, 64)
        self.fwc6 = ContFramewiseConv(64, 40)
        self.fwc7 = ContFramewiseConv(40, 40)

        self.init_weights()
        self.count_parameters()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)):
                nn.init.orthogonal_(m.weight.data)

    def count_parameters(self):
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")

    def create_phase_signals(self, periods):
        # Builds sine/cosine phase embeddings from the per-frame pitch period,
        # carrying phase0 across frames so the phase stays continuous.
        batch_size = periods.size(0)
        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
        progression = torch.repeat_interleave(progression, batch_size, 0)

        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
        chunks = []
        for sframe in range(periods.size(1)):
            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)

            # 160 samples per frame, reshaped to 4 subframes of 40 samples
            chunk_sin = torch.sin(f * progression + phase0)
            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 40)

            chunk_cos = torch.cos(f * progression + phase0)
            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 40)

            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)

            phase0 = phase0 + 160 * f

            chunks.append(chunk)

        phase_signals = torch.cat(chunks, dim=1)

        return phase_signals

    def gain_multiply(self, x, c0):
        # Converts the first BFCC coefficient into a linear gain and applies
        # it samplewise, one gain value per 160-sample frame.
        gain = 10**(0.5*c0/np.sqrt(18.0))
        gain = torch.repeat_interleave(gain, 160, dim=-1)
        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)

        return x * gain

    def forward(self, pitch_period, bfcc_with_corr, x0):
        # Normalize the 320-sample continuation signal and prepend its
        # log-energy, giving the 321-dim input expected by cont_net.
        norm_x0 = torch.norm(x0, 2, dim=-1, keepdim=True)
        x0 = x0 / torch.sqrt(1e-8 + norm_x0**2)
        x0 = torch.cat((torch.log(norm_x0 + 1e-7), x0), dim=-1)

        # Conditioning: pitch-driven phase embedding plus upsampled features
        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
        feat_in = torch.cat((p_embed, envelope), dim=1)

        wav_latent1 = self.feat_in_nl1(self.feat_in_conv1(feat_in).permute(0, 2, 1).contiguous())

        cont_latent = self.cont_net(x0)

        rnn_out = self.rnn(wav_latent1, cont_latent)

        fwc1_out = self.fwc1(rnn_out, cont_latent)
        fwc2_out = self.fwc2(fwc1_out, cont_latent)
        fwc3_out = self.fwc3(fwc2_out, cont_latent)
        fwc4_out = self.fwc4(fwc3_out, cont_latent)
        fwc5_out = self.fwc5(fwc4_out, cont_latent)
        fwc6_out = self.fwc6(fwc5_out, cont_latent)
        fwc7_out = self.fwc7(fwc6_out, cont_latent)

        # Flatten the 40-sample subframes into a waveform and apply the gain
        # derived from the first BFCC coefficient.
        waveform = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
        waveform = self.gain_multiply(waveform, bfcc_with_corr[:, :, :1])

        return waveform
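
# Hedged end-to-end sketch (not part of the original file): shapes inferred
# from the code above. With 10 conditioning frames, 10 pitch periods and a
# 320-sample continuation signal, the model emits 10 x 160 = 1600 samples.
# The pitch-period range used here is illustrative, not prescribed.
def _demo_fwgan400():
    model = FWGAN400ContLarge()
    pitch_period = torch.randint(32, 256, (1, 10)).float()
    bfcc_with_corr = torch.randn(1, 10, 19)
    x0 = torch.randn(1, 320)
    waveform = model(pitch_period, bfcc_with_corr, x0)
    assert waveform.shape == (1, 1600)

if __name__ == "__main__":
    _demo_fwgan400()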