1""" 2Pitch Estimation Models and dataloaders 3 - Classification Based (Input features, output logits) 4""" 5 6import torch 7import numpy as np 8 9class PitchDNNIF(torch.nn.Module): 10 11 def __init__(self, input_dim=88, gru_dim=64, output_dim=192): 12 super().__init__() 13 14 self.activation = torch.nn.Tanh() 15 self.initial = torch.nn.Linear(input_dim, gru_dim) 16 self.hidden = torch.nn.Linear(gru_dim, gru_dim) 17 self.gru = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, batch_first=True) 18 self.upsample = torch.nn.Linear(gru_dim, output_dim) 19 20 def forward(self, x): 21 22 x = self.initial(x) 23 x = self.activation(x) 24 x = self.hidden(x) 25 x = self.activation(x) 26 x,_ = self.gru(x) 27 x = self.upsample(x) 28 x = self.activation(x) 29 x = x.permute(0,2,1) 30 31 return x 32 33class PitchDNNXcorr(torch.nn.Module): 34 35 def __init__(self, input_dim=90, gru_dim=64, output_dim=192): 36 super().__init__() 37 38 self.activation = torch.nn.Tanh() 39 40 self.conv = torch.nn.Sequential( 41 torch.nn.ZeroPad2d((2, 0, 1, 1)), 42 torch.nn.Conv2d(1, 8, 3, bias=True), 43 self.activation, 44 torch.nn.ZeroPad2d((2,0,1,1)), 45 torch.nn.Conv2d(8, 8, 3, bias=True), 46 self.activation, 47 torch.nn.ZeroPad2d((2,0,1,1)), 48 torch.nn.Conv2d(8, 1, 3, bias=True), 49 self.activation, 50 ) 51 52 self.downsample = torch.nn.Sequential( 53 torch.nn.Linear(input_dim, gru_dim), 54 self.activation 55 ) 56 self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True) 57 self.upsample = torch.nn.Sequential( 58 torch.nn.Linear(gru_dim,output_dim), 59 self.activation 60 ) 61 62 def forward(self, x): 63 x = self.conv(x.unsqueeze(-1).permute(0,3,2,1)).squeeze(1) 64 x,_ = self.GRU(self.downsample(x.permute(0,2,1))) 65 x = self.upsample(x).permute(0,2,1) 66 67 return x 68 69class PitchDNN(torch.nn.Module): 70 """ 71 Joint IF-xcorr 72 1D CNN on IF, merge with xcorr, 2D CNN on merged + GRU 73 """ 74 75 def __init__(self,input_IF_dim=88, input_xcorr_dim=224, gru_dim=64, output_dim=192): 76 super().__init__() 77 78 self.activation = torch.nn.Tanh() 79 80 self.if_upsample = torch.nn.Sequential( 81 torch.nn.Linear(input_IF_dim,64), 82 self.activation, 83 torch.nn.Linear(64,64), 84 self.activation, 85 ) 86 87 self.conv = torch.nn.Sequential( 88 torch.nn.ZeroPad2d((2,0,1,1)), 89 torch.nn.Conv2d(1, 4, 3, bias=True), 90 self.activation, 91 torch.nn.ZeroPad2d((2,0,1,1)), 92 torch.nn.Conv2d(4, 1, 3, bias=True), 93 self.activation, 94 ) 95 96 self.downsample = torch.nn.Sequential( 97 torch.nn.Linear(64 + input_xcorr_dim, gru_dim), 98 self.activation 99 ) 100 self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True) 101 self.upsample = torch.nn.Sequential( 102 torch.nn.Linear(gru_dim, output_dim) 103 ) 104 105 def forward(self, x): 106 xcorr_feat = x[:,:,:224] 107 if_feat = x[:,:,224:] 108 xcorr_feat = self.conv(xcorr_feat.unsqueeze(-1).permute(0,3,2,1)).squeeze(1).permute(0,2,1) 109 if_feat = self.if_upsample(if_feat) 110 x = torch.cat([xcorr_feat,if_feat],axis = - 1) 111 x,_ = self.GRU(self.downsample(x)) 112 x = self.upsample(x).permute(0,2,1) 113 114 return x 115 116 117# Dataloaders 118class Loader(torch.utils.data.Dataset): 119 def __init__(self, features_if, file_pitch, confidence_threshold=0.4, dimension_if=30, context=100): 120 self.if_feat = np.memmap(features_if, dtype=np.float32).reshape(-1,3*dimension_if) 121 122 # Resolution of 20 cents 123 self.cents = np.rint(np.load(file_pitch)[0,:]/20) 124 self.cents = np.clip(self.cents,0,179) 125 
# Dataloaders

class Loader(torch.utils.data.Dataset):
    """IF-feature dataset with pitch-class and confidence targets, chunked into fixed-length sequences."""

    def __init__(self, features_if, file_pitch, confidence_threshold=0.4, dimension_if=30, context=100):
        self.if_feat = np.memmap(features_if, dtype=np.float32).reshape(-1, 3 * dimension_if)

        # Pitch targets quantized with a resolution of 20 cents
        pitch = np.load(file_pitch)
        self.cents = np.rint(pitch[0, :] / 20)
        self.cents = np.clip(self.cents, 0, 179)
        self.confidence = pitch[1, :]

        # Zero out frames with low CREPE confidence
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context

        # Clip features and targets to a common length
        size_common = min(self.if_feat.shape[0], self.cents.shape[0])
        self.if_feat = self.if_feat[:size_common, :]
        self.cents = self.cents[:size_common]
        self.confidence = self.confidence[:size_common]

        # Split into non-overlapping chunks of `context` frames
        frame_max = self.if_feat.shape[0] // context
        self.if_feat = np.reshape(self.if_feat[:frame_max * context, :], (frame_max, context, 3 * dimension_if))
        self.cents = np.reshape(self.cents[:frame_max * context], (frame_max, context))
        self.confidence = np.reshape(self.confidence[:frame_max * context], (frame_max, context))

    def __len__(self):
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        return torch.from_numpy(self.if_feat[index, :, :]), torch.from_numpy(self.cents[index]), torch.from_numpy(self.confidence[index])


class PitchDNNDataloader(torch.utils.data.Dataset):
    """Dataset of int8 xcorr+IF features with (frequency, confidence) ground truth, chunked into fixed-length sequences."""

    def __init__(self, features, file_pitch, confidence_threshold=0.4, context=100, choice_data='both'):
        self.feat = np.memmap(features, mode='r', dtype=np.int8).reshape(-1, 312)
        self.xcorr = self.feat[:, :224]
        self.if_feat = self.feat[:, 224:]
        ground_truth = np.memmap(file_pitch, mode='r', dtype=np.float32).reshape(-1, 2)

        # Quantize pitch to 20-cent classes relative to 62.5 Hz (60 classes per octave)
        self.cents = np.rint(60 * np.log2(ground_truth[:, 0] / 62.5))
        mask = (self.cents >= 0).astype('float32') * (self.cents <= 180).astype('float32')
        self.cents = np.clip(self.cents, 0, 179)
        self.confidence = ground_truth[:, 1] * mask

        # Zero out frames with low CREPE confidence
        self.confidence[self.confidence < confidence_threshold] = 0
        self.context = context

        self.choice_data = choice_data

        # Split into non-overlapping chunks of `context` frames
        frame_max = self.if_feat.shape[0] // context
        self.if_feat = np.reshape(self.if_feat[:frame_max * context, :], (frame_max, context, 88))
        self.cents = np.reshape(self.cents[:frame_max * context], (frame_max, context))
        self.xcorr = np.reshape(self.xcorr[:frame_max * context, :], (frame_max, context, 224))
        self.confidence = np.reshape(self.confidence[:frame_max * context], (frame_max, context))

    def __len__(self):
        return self.if_feat.shape[0]

    def __getitem__(self, index):
        cents = torch.from_numpy(self.cents[index])
        confidence = torch.from_numpy(self.confidence[index])
        # int8 features are rescaled to roughly [-1, 1]
        if self.choice_data == 'both':
            features = torch.cat([torch.from_numpy((1. / 127) * self.xcorr[index, :, :]),
                                  torch.from_numpy((1. / 127) * self.if_feat[index, :, :])], dim=-1)
        elif self.choice_data == 'if':
            features = torch.from_numpy((1. / 127) * self.if_feat[index, :, :])
        else:
            features = torch.from_numpy((1. / 127) * self.xcorr[index, :, :])
        return features, cents, confidence
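
# Illustrative usage sketch (not part of the original module). The paths, batch
# size, and confidence-weighted loss below are assumptions; the feature file is
# expected to be an int8 memmap with 312 values per frame (224 xcorr followed by
# 88 IF) and the pitch file a float32 memmap of (frequency_hz, confidence) pairs,
# matching PitchDNNDataloader above.
def _example_training_step(features_path="features.i8", pitch_path="pitch.f32"):
    dataset = PitchDNNDataloader(features_path, pitch_path, confidence_threshold=0.4,
                                 context=100, choice_data='both')
    loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
    model = PitchDNN()
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    features, cents, confidence = next(iter(loader))
    logits = model(features.float())                     # (batch, 192, context)
    # Per-frame cross-entropy against the 20-cent pitch classes; one plausible
    # use of the confidence channel is to down-weight unreliable frames.
    loss = (criterion(logits, cents.long()) * confidence).mean()
    loss.backward()
    return loss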