import os
import time
import torch
import numpy as np
from scipy import signal as si
from scipy.io import wavfile
import argparse

from models import model_dict

parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file')
parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')


########################### Signal Processing Layers ###########################

def preemphasis(x, coef=-0.85):
    """Apply the first-order FIR filter ``y[n] = x[n] + coef * x[n-1]``.

    Args:
        x: 1-D input signal.
        coef: filter coefficient (default -0.85, i.e. ``1 - 0.85 z^-1``).

    Returns:
        The filtered signal as float32.
    """
    return si.lfilter(np.array([1.0, coef]), np.array([1.0]), x).astype('float32')


def deemphasis(x, coef=-0.85):
    """Apply the first-order IIR filter ``y[n] = x[n] - coef * y[n-1]``.

    This is the exact inverse of :func:`preemphasis` for the same ``coef``.

    Args:
        x: 1-D input signal.
        coef: filter coefficient (default -0.85).

    Returns:
        The filtered signal as float32.
    """
    return si.lfilter(np.array([1.0]), np.array([1.0, coef]), x).astype('float32')


gamma = 0.92
# Per-lag weights gamma**16 ... gamma**1, ordered oldest lag first so they line up
# with the synthesis history buffer used by lpc_synthesis_one_frame.
weighting_vector = np.array([gamma**i for i in range(16, 0, -1)])


def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=None):
    """Run an all-pole LPC synthesis filter over one frame of samples.

    With ``p = len(filt)``, each output sample is::

        y[n] = x[n] - sum_{k=1..p} weighting_vector[p-k] * filt[k-1] * y[n-k]

    so ``filt[k-1]`` is the coefficient for the output ``k`` samples back, and
    ``weighting_vector`` is aligned position-by-position with ``buffer``.

    Args:
        frame: 1-D input samples for this frame.
        filt: 1-D filter coefficients (length p).
        buffer: the p most recent past outputs, oldest first. It is NOT modified.
        weighting_vector: per-position weights aligned with ``buffer``. Defaults
            to all ones (no weighting).

    Returns:
        An array with the shape and dtype of ``frame`` holding the filtered samples.
    """
    order = filt.shape[0]
    if weighting_vector is None:
        weighting_vector = np.ones(order)

    # Reverse the coefficients so they line up with the oldest-first history window.
    coeffs = np.flip(filt)
    num_samples = frame.shape[0]

    # history[:order] holds the past outputs (oldest first). Each new output is
    # appended behind them, so history[i:i + order] is always the `order` most
    # recent outputs. This avoids both mutating the caller's buffer and
    # reallocating with np.roll on every sample.
    history = np.empty(order + num_samples)
    history[:order] = buffer

    out = np.zeros_like(frame)
    for i in range(num_samples):
        s = frame[i] - np.dot(history[i:i + order] * weighting_vector, coeffs)
        history[order + i] = s
        out[i] = s

    return out


def inverse_perceptual_weighting(pw_signal, filters, weighting_vector, frame_size=160):
    """Undo perceptual weighting frame by frame.

    inverse perceptual weighting = H_preemph / W(z/gamma)

    The signal is pre-emphasised first. Each ``frame_size``-sample frame is then
    run through :func:`lpc_synthesis_one_frame` with its own row of ``filters``.
    The filter's output history is carried across frame boundaries.

    Args:
        pw_signal: 1-D perceptually weighted signal.
        filters: 2-D array with one row of filter coefficients per frame.
        weighting_vector: per-lag weights forwarded to lpc_synthesis_one_frame.
        frame_size: samples per frame (default 160).

    Returns:
        A float32 array the same length as ``pw_signal``. Samples after the last
        full frame are left at zero.

    Raises:
        ValueError: if ``frame_size`` is not positive, or the number of full
            frames does not equal ``filters.shape[0]``.
    """
    if frame_size <= 0:
        raise ValueError(f"frame_size must be positive, got {frame_size}")

    pw_signal = preemphasis(pw_signal)

    num_frames = pw_signal.shape[0] // frame_size
    if num_frames != filters.shape[0]:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            f"signal has {num_frames} full {frame_size}-sample frames "
            f"but {filters.shape[0]} filter rows were given"
        )

    order = filters.shape[1]
    synthesized = np.zeros_like(pw_signal)
    history = np.zeros(order)

    for frame_idx in range(num_frames):
        start = frame_idx * frame_size
        stop = start + frame_size
        frame_out = lpc_synthesis_one_frame(
            pw_signal[start:stop], filters[frame_idx], history, weighting_vector
        )
        synthesized[start:stop] = frame_out
        # Keep the `order` most recent outputs for the next frame. Concatenating
        # first also works when frame_size < order.
        history = np.concatenate((history, frame_out))[-order:]

    return synthesized
def process_item(generator, feature_filename, output_filename, verbose=False):
    """Vocode one feature file and write the result as a 16 kHz, 16-bit PCM wav.

    The feature file is read as raw float32 values grouped into 36-value frames:
      * columns 0-17 (``bfcc``) plus column 19 offset by 0.5 (``corr``) form the
        conditioning features passed to the generator;
      * column 18 is mapped to an integer period via ``int(0.1 + 50*x + 100)``;
      * the last 16 columns are the filter coefficients passed to
        ``inverse_perceptual_weighting``.
    Any trailing partial frame (file size not a multiple of 36 floats) is dropped.

    Args:
        generator: model called as ``generator(period, features, history_audio)``.
        feature_filename: path of the raw float32 feature file.
        output_filename: path of the wav file to write.
        verbose: if True, print generation time and real-time factor.
    """
    feat = np.memmap(feature_filename, dtype='float32', mode='r')

    num_feat_frames = len(feat) // 36
    # Truncate to whole frames; reshaping the full buffer would fail otherwise.
    feat = np.reshape(feat[:num_feat_frames * 36], (num_feat_frames, 36))

    bfcc = np.copy(feat[:, :18])
    corr = np.copy(feat[:, 19:20]) + 0.5
    bfcc_with_corr = torch.from_numpy(np.hstack((bfcc, corr))).float().unsqueeze(0)

    period = torch.from_numpy(
        (0.1 + 50 * np.copy(feat[:, 18:19]) + 100).astype('int32')
    ).type(torch.long).view(1, -1)

    lpc_filters = np.copy(feat[:, -16:])

    start_time = time.perf_counter()
    # Inference only: skip autograd bookkeeping (saves memory and time, same output).
    with torch.no_grad():
        # zero history audio frames -> the vocoder runs in complete synthesis mode
        x1 = generator(period, bfcc_with_corr, torch.zeros(1, 320))
    total_time = time.perf_counter() - start_time

    x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
    gen_seconds = len(x1) / 16000
    out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))
    if verbose:
        print(f"Took {total_time:.3f}s to generate {len(x1)} samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")

    # Scale to 16-bit PCM and clip to the int16 range.
    out = np.clip(np.round(2**15 * out), -2**15, 2**15 - 1).astype(np.int16)
    wavfile.write(output_filename, 16000, out)


########################### The inference loop over folder containing lpcnet feature files #################################
if __name__ == "__main__":

    args = parser.parse_args()

    generator = model_dict[args.model]()

    # Load the checkpoint for the selected FWGAN variant.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
    # only load trusted weight files (or pass weights_only=True on torch versions
    # that support it).
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    generator.load_state_dict(saved_gen)

    # weight_norm is no longer needed at inference time; strip it from every layer that has it.
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    generator.apply(_remove_weight_norm)

    # enable inference mode
    generator = generator.eval()

    print('Successfully loaded the generator model ... start generation:')

    if os.path.isdir(args.input):

        os.makedirs(args.output, exist_ok=True)

        # Sorted for a deterministic processing order. Subdirectories and other
        # non-regular entries are skipped because np.memmap cannot read them.
        for fn in sorted(os.listdir(args.input)):
            feature_filename = os.path.join(args.input, fn)
            if not os.path.isfile(feature_filename):
                continue
            print(f"processing input {fn}...")
            output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
            process_item(generator, feature_filename, output_filename)
    else:
        process_item(generator, args.input, args.output)

    print("Finished!")