1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30""" Dataset for LPCNet training """ 31import os 32 33import yaml 34import torch 35import numpy as np 36from torch.utils.data import Dataset 37 38 39scale = 255.0/32768.0 40scale_1 = 32768.0/255.0 41def ulaw2lin(u): 42 u = u - 128 43 s = np.sign(u) 44 u = np.abs(u) 45 return s*scale_1*(np.exp(u/128.*np.log(256))-1) 46 47 48def lin2ulaw(x): 49 s = np.sign(x) 50 x = np.abs(x) 51 u = (s*(128*np.log(1+scale*x)/np.log(256))) 52 u = np.clip(128 + np.round(u), 0, 255) 53 return u 54 55 56def run_lpc(signal, lpcs, frame_length=160): 57 num_frames, lpc_order = lpcs.shape 58 59 prediction = np.concatenate( 60 [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)] 61 ) 62 error = signal[lpc_order :] - prediction 63 64 return prediction, error 65 66class LPCNetVocodingDataset(Dataset): 67 def __init__(self, 68 path_to_dataset, 69 features=['cepstrum', 'periods', 'pitch_corr'], 70 target='signal', 71 frames_per_sample=100, 72 feature_history=0, 73 feature_lookahead=0, 74 lpc_gamma=1): 75 76 super().__init__() 77 78 # load dataset info 79 self.path_to_dataset = path_to_dataset 80 with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f: 81 dataset = yaml.load(f, yaml.FullLoader) 82 83 # dataset version 84 self.version = dataset['version'] 85 if self.version == 1: 86 self.getitem = self.getitem_v1 87 elif self.version == 2: 88 self.getitem = self.getitem_v2 89 else: 90 raise ValueError(f"dataset version {self.version} unknown") 91 92 # features 93 self.feature_history = feature_history 94 self.feature_lookahead = feature_lookahead 95 self.frame_offset = 2 + self.feature_history 96 self.frames_per_sample = frames_per_sample 97 self.input_features = features 98 self.feature_frame_layout = dataset['feature_frame_layout'] 99 self.lpc_gamma = lpc_gamma 100 101 # load feature file 102 self.feature_file = os.path.join(path_to_dataset, dataset['feature_file']) 103 self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype']) 104 self.feature_frame_length = dataset['feature_frame_length'] 105 106 assert len(self.features) % self.feature_frame_length == 0 107 
class LPCNetVocodingDataset(Dataset):
    def __init__(self,
                 path_to_dataset,
                 features=['cepstrum', 'periods', 'pitch_corr'],
                 target='signal',
                 frames_per_sample=100,
                 feature_history=0,
                 feature_lookahead=0,
                 lpc_gamma=1):

        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, Loader=yaml.FullLoader)

        # dataset version
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        self.frame_offset = 2 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # signal keys used by the v1 format (getitem_v1 references self.input_signals)
        self.input_signals = ['last_signal', 'signal']

        # load feature file
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in the dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.target = target

        # load signals
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length

    def __getitem__(self, index):
        return self.getitem(index)

    def getitem_v2(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods to integer indices
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients are present and prediction is not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # frame positions for the lpc coefficients (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC bandwidth expansion: weight the i-th coefficient by gamma^(i+1)
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal positions (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]  # currently unused

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (drop the extra leading frame)
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        target = torch.FloatTensor(sample[self.target]) / 2 ** 15
        periods = torch.LongTensor(sample['periods'])

        return {'features': features, 'periods': periods, 'target': target}
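
    # Worked indexing example for getitem_v2 (illustrative comment, not part
    # of the original file; assumes frame_length = 160, i.e. 10 ms frames at
    # 16 kHz, the defaults frame_offset = 2 and frames_per_sample = 100, and
    # lpc_order = 16):
    #
    #   index = 0   ->  frame_start = 2, frame_stop = 102   (100 feature frames)
    #               ->  signal_start = 320, signal_stop = 16320   (16000 samples)
    #
    #   LPC branch  ->  frame_start = 1, frame_stop = 102   (101 frames; the
    #                   extra past frame supplies the last_error history)
    #               ->  signal_start = 1 * 160 - 16 + 1 = 145
    #                   signal_stop  = 102 * 160 + 1   = 16321
    #                   i.e. 16176 = 101 * 160 + 16 samples, matching run_lpc
    #
    #   offset = 160 then drops the extra leading frame, so 'prediction' lines
    #   up sample for sample with 'signal' (the one-sample shift is carried by
    #   the last_signal column).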

    def getitem_v1(self, index):
        sample = dict()

        # extract features
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods to integer indices
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        for signal_name, signal_index in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, signal_index]

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features': features, 'periods': periods, 'signals': signals, 'target': target}

    def __len__(self):
        return self.dataset_length
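

# --- Hedged usage sketch (illustrative addition, not part of the original file) ---
# 'path/to/dataset' below is a hypothetical placeholder: a real dataset
# directory must contain info.yml plus the feature and signal files it
# references. With a version-2 dataset, each batch holds 'features',
# 'periods' and 'target'.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = LPCNetVocodingDataset('path/to/dataset', frames_per_sample=100)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    batch = next(iter(loader))
    print({key: value.shape for key, value in batch.items()})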