xref: /aosp_15_r20/external/libopus/dnn/torch/neural-pitch/models.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2Pitch Estimation Models and dataloaders
3    - Classification Based (Input features, output logits)
4"""
5
6import torch
7import numpy as np
8
class PitchDNNIF(torch.nn.Module):
    """Pitch model on IF features: two tanh dense layers, a GRU, and a tanh output layer.

    Args:
        input_dim: size of the per-frame input feature vector.
        gru_dim: width of the hidden dense layers and of the GRU state.
        output_dim: number of output units per frame.

    The GRU is ``batch_first``, so the input is presumably
    (batch, time, input_dim); the output is transposed to
    (batch, output_dim, time).
    """

    def __init__(self, input_dim=88, gru_dim=64, output_dim=192):
        super().__init__()

        # Keep this construction order: it fixes the state_dict key order and
        # which random numbers each layer's initialization consumes.
        self.activation = torch.nn.Tanh()
        self.initial = torch.nn.Linear(input_dim, gru_dim)
        self.hidden = torch.nn.Linear(gru_dim, gru_dim)
        self.gru = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, batch_first=True)
        self.upsample = torch.nn.Linear(gru_dim, output_dim)

    def forward(self, x):
        """Return tanh-squashed per-frame outputs shaped (batch, output_dim, time)."""
        act = self.activation
        for dense in (self.initial, self.hidden):
            x = act(dense(x))
        seq_out, _final_state = self.gru(x)
        out = act(self.upsample(seq_out))
        # Swap the time and feature axes (same as permute(0, 2, 1) on 3-D input).
        return out.transpose(1, 2)
32
class PitchDNNXcorr(torch.nn.Module):
    """Pitch model on xcorr features: 3-layer 2-D conv stack, dense, GRU, dense.

    Args:
        input_dim: size of the per-frame input feature vector.
        gru_dim: width of the GRU (and of the dense layer feeding it).
        output_dim: number of output units per frame.

    Input is presumably (batch, time, input_dim) (the GRU is batch_first);
    output is (batch, output_dim, time), passed through tanh.
    """

    def __init__(self, input_dim=90, gru_dim=64, output_dim=192):
        super().__init__()

        self.activation = torch.nn.Tanh()

        # Three (pad, conv, tanh) triples; Sequential indices stay 0..8 so the
        # conv weights remain conv.1 / conv.4 / conv.7 in the state_dict.
        # ZeroPad2d((2, 0, 1, 1)) pads 2 on the left of the last axis only and
        # 1 on both sides of the other spatial axis, so the 3x3 conv keeps size.
        stack = []
        for c_in, c_out in ((1, 8), (8, 8), (8, 1)):
            stack.append(torch.nn.ZeroPad2d((2, 0, 1, 1)))
            stack.append(torch.nn.Conv2d(c_in, c_out, 3, bias=True))
            stack.append(self.activation)
        self.conv = torch.nn.Sequential(*stack)

        self.downsample = torch.nn.Sequential(torch.nn.Linear(input_dim, gru_dim), self.activation)
        self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
        self.upsample = torch.nn.Sequential(torch.nn.Linear(gru_dim, output_dim), self.activation)

    def forward(self, x):
        """Return per-frame outputs shaped (batch, output_dim, time)."""
        # (B, T, D) -> (B, 1, D, T): single-channel image with time on the last axis.
        image = x.unsqueeze(-1).permute(0, 3, 2, 1)
        # Drop the channel axis and go back to (B, T, D) for the frame-wise layers.
        frames = self.conv(image).squeeze(1).permute(0, 2, 1)
        seq_out, _final_state = self.GRU(self.downsample(frames))
        return self.upsample(seq_out).permute(0, 2, 1)
68
class PitchDNN(torch.nn.Module):
    """Joint IF + xcorr pitch model.

    The xcorr part of each frame goes through a 2-layer 2-D conv stack, the
    IF part through a 2-layer tanh MLP (64 units). The two are concatenated,
    projected to ``gru_dim``, run through a GRU, and projected to
    ``output_dim`` values per frame with no output activation.

    Args:
        input_IF_dim: number of IF features per frame.
        input_xcorr_dim: number of xcorr features per frame.
        gru_dim: width of the GRU.
        output_dim: number of outputs per frame.

    Input: (batch, time, input_xcorr_dim + input_IF_dim), xcorr features
    first (the GRU is batch_first). Output: (batch, output_dim, time).
    """

    def __init__(self, input_IF_dim=88, input_xcorr_dim=224, gru_dim=64, output_dim=192):
        super().__init__()

        # Split point used by forward(); previously hard-coded as 224, which
        # broke any non-default input_xcorr_dim.
        self.input_xcorr_dim = input_xcorr_dim

        self.activation = torch.nn.Tanh()

        self.if_upsample = torch.nn.Sequential(
            torch.nn.Linear(input_IF_dim, 64),
            self.activation,
            torch.nn.Linear(64, 64),
            self.activation,
        )

        # ZeroPad2d((2, 0, 1, 1)) pads only the left (earlier) side of the last
        # axis (time, after the permute in forward) and 1 on each side of the
        # feature axis, so each 3x3 conv preserves the spatial size.
        self.conv = torch.nn.Sequential(
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(1, 4, 3, bias=True),
            self.activation,
            torch.nn.ZeroPad2d((2, 0, 1, 1)),
            torch.nn.Conv2d(4, 1, 3, bias=True),
            self.activation,
        )

        self.downsample = torch.nn.Sequential(
            torch.nn.Linear(64 + input_xcorr_dim, gru_dim),
            self.activation
        )
        self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
        self.upsample = torch.nn.Sequential(
            torch.nn.Linear(gru_dim, output_dim)
        )

    def forward(self, x):
        """Return per-frame outputs shaped (batch, output_dim, time)."""
        split = self.input_xcorr_dim
        xcorr_feat = x[:, :, :split]
        if_feat = x[:, :, split:]
        # (B, T, X) -> (B, 1, X, T) for the conv stack, then back to (B, T, X).
        xcorr_feat = self.conv(xcorr_feat.unsqueeze(-1).permute(0, 3, 2, 1)).squeeze(1).permute(0, 2, 1)
        if_feat = self.if_upsample(if_feat)
        x = torch.cat([xcorr_feat, if_feat], dim=-1)
        x, _ = self.GRU(self.downsample(x))
        return self.upsample(x).permute(0, 2, 1)
115
116
117# Dataloaders
class Loader(torch.utils.data.Dataset):
    """Dataset of fixed-length sequences of float32 IF features with pitch targets.

    Args:
        features_if: raw float32 file with ``3 * dimension_if`` values per frame.
        file_pitch: ``.npy`` file; row 0 is divided by 20 and rounded to give
            the pitch bin (20-cent resolution), row 1 is the confidence.
        confidence_threshold: confidences below this are set to 0.
        dimension_if: one third of the per-frame feature size.
        context: number of frames per returned sequence; trailing frames that
            do not fill a whole sequence are dropped.
    """

    def __init__(self, features_if, file_pitch, confidence_threshold=0.4, dimension_if=30, context=100):
        feat_dim = 3 * dimension_if
        # Map read-only: the default 'r+' mode requires write access to the
        # file and lets in-place edits of returned tensors write back into it.
        self.if_feat = np.memmap(features_if, mode='r', dtype=np.float32).reshape(-1, feat_dim)

        # Load the pitch file once (previously it was read twice).
        pitch = np.load(file_pitch)

        # Resolution of 20 cents; bins clipped to [0, 179].
        self.cents = np.clip(np.rint(pitch[0, :] / 20), 0, 179)
        self.confidence = pitch[1, :]

        # Filter confidence for CREPE
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context

        # Clip both streams to the same number of frames, then to a whole
        # number of `context`-frame sequences.
        size_common = min(self.if_feat.shape[0], self.cents.shape[0])
        frame_max = size_common // context
        n_frames = frame_max * context
        self.if_feat = np.reshape(self.if_feat[:n_frames, :], (frame_max, context, feat_dim))
        self.cents = np.reshape(self.cents[:n_frames], (frame_max, context))
        self.confidence = np.reshape(self.confidence[:n_frames], (frame_max, context))

    def __len__(self):
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        """Return (features, pitch bins, confidence) for sequence ``index``."""
        # Copy out of the read-only memmap so callers get an independent,
        # writable tensor that cannot alter the dataset or the file.
        features = torch.from_numpy(np.array(self.if_feat[index, :, :]))
        return features, torch.from_numpy(self.cents[index]), torch.from_numpy(self.confidence[index])
146
class PitchDNNDataloader(torch.utils.data.Dataset):
    """Dataset of fixed-length sequences of int8 xcorr/IF features with pitch targets.

    Args:
        features: raw int8 file, 312 values per frame: 224 xcorr features
            followed by 88 IF features.
        file_pitch: raw float32 file of (pitch_hz, confidence) pairs, one per
            frame.
        confidence_threshold: confidences below this are set to 0.
        context: number of frames per returned sequence; trailing frames that
            do not fill a whole sequence are dropped.
        choice_data: 'both' returns xcorr and IF features concatenated
            (xcorr first), 'if' returns only IF features, anything else only
            xcorr features. Features are scaled by 1/127.
    """

    def __init__(self, features, file_pitch, confidence_threshold=0.4, context=100, choice_data='both'):
        self.feat = np.memmap(features, mode='r', dtype=np.int8).reshape(-1, 312)
        ground_truth = np.memmap(file_pitch, mode='r', dtype=np.float32).reshape(-1, 2)

        # Clip both streams to the same number of frames (as Loader does);
        # without this a pitch file shorter than the feature file made the
        # reshapes below raise.
        size_common = min(self.feat.shape[0], ground_truth.shape[0])
        self.xcorr = self.feat[:size_common, :224]
        self.if_feat = self.feat[:size_common, 224:]
        ground_truth = ground_truth[:size_common]

        # 60*log2(f/62.5) == 1200*log2(f/62.5)/20: 20-cent bins relative to 62.5 Hz.
        self.cents = np.rint(60 * np.log2(ground_truth[:, 0] / 62.5))
        # Frames whose bin lies outside [0, 180] get zero confidence; the
        # remaining bins are then clipped to [0, 179].
        mask = (self.cents >= 0).astype('float32') * (self.cents <= 180).astype('float32')
        self.cents = np.clip(self.cents, 0, 179)
        self.confidence = ground_truth[:, 1] * mask
        # Filter confidence for CREPE
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context

        self.choice_data = choice_data

        frame_max = size_common // context
        n_frames = frame_max * context
        self.if_feat = np.reshape(self.if_feat[:n_frames, :], (frame_max, context, 88))
        self.cents = np.reshape(self.cents[:n_frames], (frame_max, context))
        self.xcorr = np.reshape(self.xcorr[:n_frames, :], (frame_max, context, 224))
        self.confidence = np.reshape(self.confidence[:n_frames], (frame_max, context))

    def __len__(self):
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        """Return (features, pitch bins, confidence) for sequence ``index``."""
        cents = torch.from_numpy(self.cents[index])
        confidence = torch.from_numpy(self.confidence[index])
        if self.choice_data == 'both':
            features = torch.cat([torch.from_numpy((1./127)*self.xcorr[index, :, :]),
                                  torch.from_numpy((1./127)*self.if_feat[index, :, :])], dim=-1)
        elif self.choice_data == 'if':
            features = torch.from_numpy((1./127)*self.if_feat[index, :, :])
        else:
            features = torch.from_numpy((1./127)*self.xcorr[index, :, :])
        return features, cents, confidence
179