"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
from utils.layers.subconditioner import get_subconditioner
from utils.layers import DualFC

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step

from utils.misc import interleave_tensors




# MultiRateLPCNet
class MultiRateLPCNet(nn.Module):
    """LPCNet-style model whose sample-rate network runs at several rates.

    A frame-rate network turns per-frame features (plus an embedded period)
    into a conditioning vector. The sample-rate network then runs GRU A once
    every ``substeps_a * substeps_b`` samples and GRU B once every
    ``substeps_b`` samples. Two subconditioners expand the GRU outputs to the
    next-higher rate, and one ``DualFC`` head per B-substep maps the final
    hidden state to ``output_levels`` logits.
    """

    def __init__(self, config):
        """Build all layers from a nested ``config`` dict.

        Keys read: ``input_layout``, ``feature_history``, ``feature_lookahead``,
        ``signals``, ``feature_dimension``, ``period_embedding_dim``,
        ``period_levels``, ``feature_conditioning_dim``,
        ``feature_conv_kernel_size``, ``frame_size``, ``signal_levels``,
        ``signal_embedding_dim``, ``gru_a_units``, ``gru_b_units``,
        ``output_levels``, ``subconditioning`` (with ``subconditioning_a`` /
        ``subconditioning_b`` sub-dicts) and ``sparsification``.
        """
        super().__init__()

        # general parameters
        self.input_layout = config['input_layout']
        self.feature_history = config['feature_history']
        self.feature_lookahead = config['feature_lookahead']
        self.signals = config['signals']

        # frame rate network parameters
        self.feature_dimension = config['feature_dimension']
        self.period_embedding_dim = config['period_embedding_dim']
        self.period_levels = config['period_levels']
        self.feature_channels = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim = config['feature_conditioning_dim']
        self.feature_conv_kernel_size = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim,
                                       self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim,
                                       self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)

        # sample rate network parameters
        self.frame_size = config['frame_size']
        self.signal_levels = config['signal_levels']
        self.signal_embedding_dim = config['signal_embedding_dim']
        self.gru_a_units = config['gru_a_units']
        self.gru_b_units = config['gru_b_units']
        self.output_levels = config['output_levels']

        # subconditioning B (created first: subconditioning A depends on substeps_b)
        sub_config = config['subconditioning']['subconditioning_b']
        self.substeps_b = sub_config['number_of_subsamples']
        self.subcondition_signals_b = sub_config['signals']
        self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config.get('kwargs')
        if kwargs is None:
            # a missing/None kwargs entry means "no extra arguments"
            kwargs = {}

        state_size = self.gru_b_units
        self.subconditioner_b = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, len(sub_config['signals']),
            **kwargs)

        # subconditioning A
        sub_config = config['subconditioning']['subconditioning_a']
        self.substeps_a = sub_config['number_of_subsamples']
        self.subcondition_signals_a = sub_config['signals']
        self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config.get('kwargs')
        if kwargs is None:
            kwargs = {}

        state_size = self.gru_a_units
        # subconditioner A sees substeps_b time steps of each of its signals per step
        self.subconditioner_a = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
            **kwargs)


        # wrap up subconditioning, group_size_gru_a holds the number
        # of timesteps that are grouped as sample input for GRU A
        # input and group_size_subcondition_a holds the number of samples that are
        # grouped as input to pre-GRU B subconditioning
        self.group_size_gru_a = self.substeps_a * self.substeps_b
        self.group_size_subcondition_a = self.substeps_b
        self.gru_a_rate_divider = self.group_size_gru_a
        self.gru_b_rate_divider = self.substeps_b

        # gru sizes
        self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
        self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]

        # sample rate network layers
        self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)

        # sparsification
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                     gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                     gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)



        # dual FCs: one head per B-substep. Registered individually as
        # dual_fc_{i} (not via nn.ModuleList) so existing checkpoint keys stay valid.
        self.dual_fc = []
        for i in range(self.substeps_b):
            dim = self.subconditioner_b.get_output_dim(i)
            self.dual_fc.append(DualFC(dim, self.output_levels))
            self.add_module(f"dual_fc_{i}", self.dual_fc[-1])

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        """Estimate the model's compute cost in GFLOPS at sample rate ``fs``.

        Args:
            fs: audio sample rate in Hz.
            verbose: if True, print the contribution of each component.
            hierarchical_sampling: if True, divide the DualFC cost by 8.

        Returns:
            Total estimated GFLOPS (float).
        """
        gflops = 0

        # frame rate network: two valid convolutions (kernel_size taps each)
        # plus two square dense layers, counted as 2 FLOPs per MAC
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        conv1_taps = self.feature_conv1.kernel_size[0]
        conv2_taps = self.feature_conv2.kernel_size[0]
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (conv1_taps * feature_channels + (conv2_taps + 2) * conditioning_dim) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs / self.group_size_gru_a
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # subconditioning a
        subcond_a_rate = fs / self.substeps_b
        subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
        if verbose:
            print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
        gflops += subconditioning_a_complexity

        # gru b
        gru_b_rate = fs / self.substeps_b
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # subconditioning b
        subcond_b_rate = fs
        subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
        if verbose:
            print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
        gflops += subconditioning_b_complexity

        # dual fcs: each head runs once every len(self.dual_fc) samples
        for i, fc in enumerate(self.dual_fc):
            rate = fs / len(self.dual_fc)
            input_size = fc.dense1.in_features
            output_size = fc.dense1.out_features
            dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
            if hierarchical_sampling:
                dual_fc_complexity /= 8
            if verbose:
                print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
            gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops



    def sparsify(self):
        """Advance every configured GRU sparsifier by one step."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def frame_rate_network(self, features, periods):
        """Compute per-frame conditioning vectors.

        Periods are embedded and concatenated with ``features`` on the last
        axis. The result goes through two 'valid' Conv1d layers and two dense
        layers, all followed by tanh. Each valid convolution shortens the time
        axis by ``kernel_size - 1`` frames.

        Args:
            features: rank-3 tensor, channels last; presumably
                (batch, frames, feature_dimension) — TODO confirm.
            periods: integer tensor whose embedding is flattened over axes 2-3;
                presumably (batch, frames, 1) — TODO confirm.

        Returns:
            Tensor (batch, frames', feature_conditioning_dim), channels last.
        """
        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def prepare_signals(self, signals, group_size, signal_idx):
        """Extract, delay and group signals.

        Selects the channels ``signal_idx`` from ``signals``
        (batch, sequence_length, num_signals), rolls them by
        ``group_size - 1`` steps along time, and reshapes to
        (batch, sequence_length // group_size, group_size * len(signal_idx)).
        Within a group, values are slot-major, signal-minor: slot ``s`` of
        group ``j`` holds time step ``j * group_size + s - (group_size - 1)``.

        NOTE: ``torch.roll`` wraps around, so the first ``group_size - 1`` slots
        of the first group take values from the end of the sequence. The
        reshape requires ``sequence_length`` to be divisible by ``group_size``.
        """
        batch_size, sequence_length, num_signals = signals.shape

        # extract signals according to position
        signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
                            dim=-1)

        # roll back pcm to account for grouping
        signals = torch.roll(signals, group_size - 1, -2)

        # reshape
        signals = torch.reshape(signals,
            (batch_size, sequence_length // group_size, group_size * len(signal_idx)))

        return signals


    def sample_rate_network(self, signals, c, gru_states):
        """Run the multi-rate sample network in parallel (teacher forcing).

        Args:
            signals: quantized signal tensor (batch, time, num_signals) in the
                layout described by ``input_layout['signals']``.
            c: frame conditioning from ``frame_rate_network``.
            gru_states: (gru_a_state, gru_b_state) initial hidden states.

        Returns:
            (logits, (gru_a_state, gru_b_state)). The logits are the per-substep
            DualFC outputs merged by ``interleave_tensors``.
        """
        signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
        embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
        # features at GRU A rate
        c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
        # features at GRU B rate
        c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)

        y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # first round of upsampling and subconditioning
        c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
        y = self.subconditioner_a(y, c_signals_a)
        # presumably merges the per-substep outputs back into one time axis
        y = interleave_tensors(y)

        y = torch.concat((y, c_upsampled_b), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
        y = self.subconditioner_b(y, c_signals_b)

        # one DualFC head per B-substep
        y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
        y = interleave_tensors(y)

        return y, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        """Single-rate decoding step (legacy).

        NOTE(review): ``self.dual_fc`` is a Python list of heads, so
        ``self.dual_fc(y)`` below is not callable and raises TypeError as
        written. Use ``generate`` for inference.
        """
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):
        """Training forward pass; returns log-probabilities over output levels."""
        c = self.frame_rate_network(features, periods)
        y, _ = self.sample_rate_network(signals, c, gru_states)
        log_probs = torch.log_softmax(y, dim=-1)

        return log_probs

    @staticmethod
    def _fill_signal_group(target, signal_dict, names, position, group_size):
        """Write u-law quantized signal history into the flat buffer ``target``.

        Slot ``s`` (0 <= s < group_size) holds the samples at
        ``position - group_size + 1 + s``. Within a slot, signals follow the
        order of ``names``, matching the slot-major, signal-minor layout that
        ``prepare_signals`` produces during training.
        """
        num_names = len(names)
        for slot in range(group_size):
            k = slot - group_size + 1
            for idx, name in enumerate(names):
                target[idx + slot * num_names] = lin2ulawq(
                    signal_dict[name][position + k]
                )

    def generate(self, features, periods, lpcs):
        """Synthesize audio sample by sample (autoregressive inference).

        Args:
            features: per-frame features (frames, feature_dimension), including
                ``feature_history`` leading and ``feature_lookahead`` trailing
                frames.
            periods: per-frame period indices, batched together with
                ``features`` for ``frame_rate_network``.
            lpcs: per-frame LPC coefficients, ``lpcs[frame]`` of length
                ``lpc_order``.

        Returns:
            int16 tensor (on CPU) with ``num_frames * frame_size`` samples, where
            ``num_frames = features.shape[0] - feature_history - feature_lookahead``.
        """
        with torch.no_grad():
            device = next(self.parameters()).device

            num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order = lpcs.shape[-1]
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers live on the model's device because they are mixed
            # with device tensors (LPC coefficients, network inputs)
            buffer_size = num_frames * self.frame_size + lpc_order + 1
            last_signal = torch.zeros(buffer_size, device=device)
            prediction = torch.zeros(buffer_size, device=device)
            last_error = torch.zeros(buffer_size, device=device)
            output = torch.zeros(num_frames * self.frame_size, dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units), device=device)
            gru_b_state = torch.zeros((1, 1, self.gru_b_units), device=device)

            # quantized inputs for GRU A, initialised to level 128
            input_signals = 128 + torch.zeros(self.group_size_gru_a * len(self.signals),
                                              dtype=torch.long, device=device)
            # conditioning signals for subconditioner a
            c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a),
                                            dtype=torch.long, device=device)
            # conditioning signals for subconditioner b
            c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long, device=device)

            # signal dict
            signal_dict = {
                'prediction' : prediction,
                'last_error' : last_error,
                'last_signal' : last_signal
            }

            # push data to device
            features = features.to(device)
            periods = periods.to(device)
            lpcs = lpcs.to(device)

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
                # flipped, negated LPCs so the prediction is a plain dot product
                a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c = c[:, frame_index : frame_index + 1, :]

                for i in range(0, self.frame_size, self.group_size_gru_a):
                    pcm_position = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # calculate newest prediction (recomputed here since `a`
                    # changes at frame boundaries)
                    prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

                    # prepare input
                    self._fill_signal_group(input_signals, signal_dict, self.signals,
                                            pcm_position, self.group_size_gru_a)

                    # run GRU A
                    embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
                    embed_signals = torch.flatten(embed_signals, 2)
                    y = torch.cat((embed_signals, current_c), dim=-1)
                    h_a, gru_a_state = self.gru_a(y, gru_a_state)

                    # loop over substeps_a
                    for step_a in range(self.substeps_a):
                        # prepare conditioning input (stride = number of A-signals)
                        self._fill_signal_group(c_signals_a, signal_dict, self.subcondition_signals_a,
                                                pcm_position, self.group_size_subcondition_a)

                        # subconditioning
                        h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))

                        # run GRU B
                        y = torch.cat((h_a, current_c), dim=-1)
                        h_b, gru_b_state = self.gru_b(y, gru_b_state)

                        # loop over substeps b
                        for step_b in range(self.substeps_b):
                            # prepare subconditioning input (current position only)
                            self._fill_signal_group(c_signals_b, signal_dict, self.subcondition_signals_b,
                                                    pcm_position, 1)

                            # subcondition
                            h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))

                            # run dual FC
                            probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)

                            # sample
                            new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))

                            # update signals
                            sig = new_exc + prediction[pcm_position]
                            last_error[pcm_position + 1] = new_exc
                            last_signal[pcm_position + 1] = sig

                            # first-order recursive filter (coefficient 0.85) on the output
                            mem = 0.85 * mem + float(sig)
                            output[output_position] = clip_to_int16(round(mem))

                            # increase positions
                            pcm_position += 1
                            output_position += 1

                            # calculate next prediction
                            prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

            return output