"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe and Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
# Encode redundancy for Opus neural FEC with a trained RDO-VAE model.
#
# Pipeline: pad the input signal, extract features with the external
# dump_data tool, run the encoder, quantize the latent symbols, run the
# decoder on every packet and write the decoded features to disk
# (optionally assembled according to a packet-loss trace).

import argparse
import itertools
import os
import subprocess


import numpy as np
from scipy.io import wavfile
import tensorflow as tf

from rdovae import new_rdovae_model, pvq_quantize, apply_dead_zone, sq_rate_metric
from fec_packets import write_fec_packets, read_fec_packets


def _assemble_fec_features(packet_features, loss):
    """Assemble the subframe sequence recovered from FEC packets under a loss trace.

    Args:
        packet_features: array indexed as (packet, subframe, feature), i.e. the
            decoded features of every packet. Each packet advances the output
            by two subframes.
        loss: 1-D sequence with one entry per packet; 0 means received and any
            other value means lost (per the --lossfile help text).

    Returns:
        float32 array of shape (2 * num_packets, num_features). Whenever a
        packet is received, or the final packet is reached, the current frame
        and all preceding lost frames are filled from the last subframes of
        that packet. Subframes lying further back than the packet's span
        cannot be recovered and are left as zeros.

    Raises:
        ValueError: if the loss trace has fewer entries than there are packets.
    """
    num_packets, subframes_per_packet, num_feats = packet_features.shape
    if len(loss) < num_packets:
        raise ValueError(f'loss trace has {len(loss)} entries but {num_packets} packets were encoded')

    fec_out = np.zeros((2 * num_packets, num_feats), dtype=np.float32)
    ptr = 0     # next output subframe to fill
    count = 2   # number of subframes the next usable packet has to fill
    for i in range(num_packets):
        # the final packet is always used so that the output is completely filled
        if loss[i] == 0 or i == num_packets - 1:
            # A long loss burst may need more subframes than one packet holds;
            # only the most recent `available` ones can be recovered.
            available = min(count, subframes_per_packet)
            fec_out[ptr + count - available:ptr + count, :] = \
                packet_features[i, subframes_per_packet - available:, :]
            ptr += count
            count = 2
        else:
            count += 2
    return fec_out


debug = False

if debug:
    # Hard-coded arguments for interactive debugging; this must provide every
    # attribute the script reads from `args` below.
    args = type('dummy', (object,),
    {
        'input' : 'item1.wav',
        'weights' : 'testout/rdovae_alignment_fix_1024_120.h5',
        'enc_lambda' : 0.0007,
        'output' : "test_0007.fec",
        'cond_size' : 1024,
        'quant_levels' : 40,
        'num_redundancy_frames' : 64,
        'extra_delay' : 0,
        'dump_data' : './dump_data',
        'lossfile' : None,
        'debug_output' : False,
    })()
    os.environ['CUDA_VISIBLE_DEVICES']=""
else:
    parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames')

    parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
    parser.add_argument('weights', metavar='<weights>', help='trained model file (.h5)')
    parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

    parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
    parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
    parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 40)", default=40)
    parser.add_argument('--num-redundancy-frames', default=64, type=int, help='number of redundancy frames (20ms) per packet (default 64)')
    parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
    parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')

    parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

    args = parser.parse_args()

model, encoder, decoder, qembedding = new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=1, nb_quant=args.quant_levels, cond_size=args.cond_size)
model.load_weights(args.weights)

lpc_order = 16

## prepare input signal
# SILK frame size is 20ms and LPCNet subframes are 10ms
subframe_size = 160
frame_size = 2 * subframe_size

# 91 samples delay to align with SILK decoded frames
silk_delay = 91

# prepend zeros to have enough history to produce the first package
zero_history = (args.num_redundancy_frames - 1) * frame_size

# dump data has a (feature) delay of 10ms
dump_data_delay = 160

total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay
if total_delay < 0:
    raise ValueError(f'total delay must be non-negative, got {total_delay} samples; '
                     'increase --num-redundancy-frames or --extra-delay')

# load signal
if args.input.endswith(('.raw', '.pcm', '.sw')):
    signal = np.fromfile(args.input, dtype='int16')
elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
    # The padded signal is handed to dump_data as raw int16, so any other
    # sample format or a multi-channel file would be silently misread.
    if signal.dtype != np.int16 or signal.ndim != 1:
        raise ValueError(f'expected mono int16 wav input, got dtype {signal.dtype} with shape {signal.shape}: {args.input}')
    # 10ms subframes of 160 samples imply a 16kHz signal
    if fs != 16000:
        print(f"warning: {args.input} has sampling rate {fs} Hz, expected 16000 Hz")
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# prepend the alignment/history delay and fill up the last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)

# call dump_data on the padded signal to create features; an argument list
# (no shell) keeps paths with spaces or shell metacharacters intact
feature_file = os.path.splitext(args.input)[0] + '_features.f32'
command = [args.dump_data, '-test', padded_signal_file, feature_file]
r = subprocess.run(command)
if r.returncode != 0:
    raise RuntimeError(f"command '{' '.join(command)}' failed with exit code {r.returncode}")

# load features: dump_data writes nb_features float32 values per 10ms subframe,
# of which only the first nb_used_features are fed to the model
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features

features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
num_subframes = 2 * (num_subframes // 2)    # whole 20ms frames only
num_frames = num_subframes // 2

# truncate before reshaping so a trailing partial frame cannot break the reshape
features = np.reshape(features[:num_subframes * nb_features], (1, num_subframes, nb_features))
features = features[:, :, :nb_used_features]

# variable quantizer depending on the delay
q0 = 3
q1 = 15
quant_id = np.round(q1 + (q0-q1)*np.arange(args.num_redundancy_frames//2)/args.num_redundancy_frames).astype('int16')

quant_embed = qembedding(quant_id)

# run encoder
print("running fec encoder...")
symbols, gru_state_dec = encoder.predict(features)

# quantization parameters derived from the quantizer embedding
nsymbols = 80
quant_scale = tf.math.softplus(quant_embed[:, :nsymbols]).numpy()
dead_zone = tf.math.softplus(quant_embed[:, nsymbols : 2 * nsymbols]).numpy()
quant_gru_state_dec = pvq_quantize(gru_state_dec, 82)

# rate-estimate parameters
hard_distr_embed = tf.math.sigmoid(quant_embed[:, 4 * nsymbols : ]).numpy()

# one packet per frame from `offset` on, each holding every second latent
# symbol vector of the preceding frames
input_length = args.num_redundancy_frames // 2
offset = args.num_redundancy_frames - 1
num_packets = num_frames - offset
if num_packets <= 0:
    raise ValueError(f'input signal {args.input} is too short to produce a single packet')

# pack symbols for batch processing
sym_batch = np.zeros((num_packets, input_length, nsymbols), dtype='float32')
quant_state = quant_gru_state_dec[0, offset:num_frames, :]
for i in range(offset, num_frames):
    sym_batch[i - offset, :, :] = symbols[0, i - 2 * input_length + 2 : i + 1 : 2, :]

# quantize symbols
sym_batch = sym_batch * quant_scale
sym_batch = apply_dead_zone([sym_batch, dead_zone]).numpy()
sym_batch = np.round(sym_batch)

hard_distr_embed = np.broadcast_to(hard_distr_embed, (sym_batch.shape[0], sym_batch.shape[1], 2*sym_batch.shape[2]))
fake_lambda = np.ones((sym_batch.shape[0], sym_batch.shape[1], 1), dtype='float32')
rate_input = np.concatenate((sym_batch, hard_distr_embed, fake_lambda), axis=-1)
rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
print("average rate = ", np.mean(rates[args.num_redundancy_frames:,:]))

# run decoder on all packets at once; convert to numpy once for the slicing below
sym_batch = sym_batch / quant_scale
features = np.asarray(decoder([sym_batch, quant_state]))

# NOTE: the .fec packet file itself is currently not written (the
# write_fec_packets call is disabled); packet_file only determines the names
# of the feature files written below.
packet_file = args.output if args.output.endswith('.fec') else args.output + '.fec'
output_stem = packet_file[:-len('.fec')]

if args.lossfile is not None:
    # atleast_1d: a single-entry trace is loaded as a 0-d array
    loss = np.atleast_1d(np.loadtxt(args.lossfile, dtype='int16'))
    fec_out = _assemble_fec_features(features, loss)

    fec_out_full = np.zeros((fec_out.shape[0], nb_features), dtype=np.float32)
    fec_out_full[:, :nb_used_features] = fec_out

    fec_out_full.tofile(output_stem + '_fec.f32')

if args.debug_output:
    # one single-packet slice per decoded packet, like in the original
    # per-packet version, for sanity checks
    packets = [features[j:j + 1, :, :] for j in range(num_packets)]

    batches = [4]
    offsets = [0, (args.num_redundancy_frames - 2)*2]
    # concatenate `batch` subframes ending `debug_offset` subframes before the
    # end of every (batch // 2)-th packet
    for batch, debug_offset in itertools.product(batches, offsets):

        stop = packets[0].shape[1] - debug_offset
        print(batch, debug_offset, stop)
        test_features = np.concatenate([packet[:, stop - batch: stop, :] for packet in packets[::batch//2]], axis=1)

        test_features_full = np.zeros((test_features.shape[1], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[0, :, :]

        debug_file = output_stem + f'_tf_batch{batch}_offset{debug_offset}.f32'
        print(f"writing debug output {debug_file}")
        test_features_full.tofile(debug_file)