"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe and Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
import os
import subprocess
import argparse


import numpy as np
from scipy.io import wavfile
import tensorflow as tf

from rdovae import new_rdovae_model, pvq_quantize, apply_dead_zone, sq_rate_metric
from fec_packets import write_fec_packets, read_fec_packets


debug = False

if debug:
    # hard-coded arguments for debugging; must provide every attribute that is read below
    args = type('dummy', (object,),
    {
        'input' : 'item1.wav',
        'weights' : 'testout/rdovae_alignment_fix_1024_120.h5',
        'enc_lambda' : 0.0007,
        'output' : "test_0007.fec",
        'cond_size' : 1024,
        'quant_levels' : 40,
        'num_redundancy_frames' : 64,
        'extra_delay' : 0,
        'dump_data' : './dump_data',
        'lossfile' : None,
        'debug_output' : False
    })()
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
else:
    parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with VoIP applications and 20 ms frames.')

    parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
    parser.add_argument('weights', metavar='<weights>', help='trained model file (.h5)')
#    parser.add_argument('enc_lambda', metavar='<lambda>', type=float, help='lambda for controlling encoder rate')
    parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

    parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
    parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
    parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 40)", default=40)
    parser.add_argument('--num-redundancy-frames', default=64, type=int, help='number of redundancy frames (20ms) per packet (default 64)')
    parser.add_argument('--extra-delay', default=0, type=int, help="the last features in a packet are calculated from decoder-aligned samples; use this option to add extra delay (in samples at 16kHz)")
    parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')

    parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

    args = parser.parse_args()
model, encoder, decoder, qembedding = new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=1, nb_quant=args.quant_levels, cond_size=args.cond_size)
model.load_weights(args.weights)

lpc_order = 16

## prepare input signal
# SILK frame size is 20ms and LPCNet subframes are 10ms
subframe_size = 160
frame_size = 2 * subframe_size

# 91 samples delay to align with SILK decoded frames
silk_delay = 91

# prepend zeros to have enough history to produce the first packet
zero_history = (args.num_redundancy_frames - 1) * frame_size

# dump_data has a (feature) delay of 10ms
dump_data_delay = 160

total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay
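# e.g. with the defaults (num_redundancy_frames=64, extra_delay=0):
# total_delay = 91 + 63 * 320 + 0 - 160 = 20091 samples (~1.26 s at 16 kHz)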

# load signal
if args.input.endswith('.raw') or args.input.endswith('.pcm') or args.input.endswith('.sw'):
    signal = np.fromfile(args.input, dtype='int16')

elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# fill up last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)

# write signal and call dump_data to create features

feature_file = os.path.splitext(args.input)[0] + '_features.f32'
command = f"{args.dump_data} -test {padded_signal_file} {feature_file}"
r = subprocess.run(command, shell=True)
if r.returncode != 0:
    raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")
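# dump_data writes one float32 feature vector per 10ms subframe: the acoustic features used
# by the model followed by the LPC coefficients, which are stripped off again below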

# feature dimensions
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features

# load features
features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
num_subframes = 2 * (num_subframes // 2)
num_frames = num_subframes // 2

features = np.reshape(features, (1, -1, nb_features))
features = features[:, :, :nb_used_features]
features = features[:, :num_subframes, :]

# variable quantizer depending on the delay
q0 = 3
q1 = 15
quant_id = np.round(q1 + (q0-q1)*np.arange(args.num_redundancy_frames//2)/args.num_redundancy_frames).astype('int16')
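# quantizer schedule within a packet: position 0 (the oldest redundancy frame) uses q1 and
# the index moves linearly towards q0 for the most recent frames, so redundancy is quantized
# differently depending on how far back it reaches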
#print(quant_id)

quant_embed = qembedding(quant_id)

# run encoder
print("running fec encoder...")
symbols, gru_state_dec = encoder.predict(features)
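# symbols: one vector of latent values per 20ms frame; gru_state_dec: the per-frame decoder
# GRU state, later quantized and used as the initial state when decoding a packet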

# apply quantization
nsymbols = 80
quant_scale = tf.math.softplus(quant_embed[:, :nsymbols]).numpy()
dead_zone = tf.math.softplus(quant_embed[:, nsymbols : 2 * nsymbols]).numpy()
#symbols = apply_dead_zone([symbols, dead_zone]).numpy()
#qsymbols = np.round(symbols)
quant_gru_state_dec = pvq_quantize(gru_state_dec, 82)
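# the decoder states are quantized with a pyramid vector quantizer (the second argument is
# presumably the number of PVQ pulses)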

# rate estimate
hard_distr_embed = tf.math.sigmoid(quant_embed[:, 4 * nsymbols : ]).numpy()
#rate_input = np.concatenate((qsymbols, hard_distr_embed, enc_lambda), axis=-1)
#rates = sq_rate_metric(None, rate_input, reduce=False).numpy()

# run decoder
input_length = args.num_redundancy_frames // 2
offset = args.num_redundancy_frames - 1
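# each packet carries input_length symbol vectors (one per two 20ms frames); offset is the
# first frame index with a full redundancy history and matches the (num_redundancy_frames - 1)
# frames of zero history prepended to the signal above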

packets = []
packet_sizes = []

sym_batch = np.zeros((num_frames-offset, args.num_redundancy_frames//2, nsymbols), dtype='float32')
quant_state = quant_gru_state_dec[0, offset:num_frames, :]
# pack symbols for batch processing
for i in range(offset, num_frames):
    sym_batch[i-offset, :, :] = symbols[0, i - 2 * input_length + 2 : i + 1 : 2, :]
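# each row of sym_batch collects, oldest to newest, every other frame's symbol vector up to
# and including the packet's last frame i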

# quantize symbols
sym_batch = sym_batch * quant_scale
sym_batch = apply_dead_zone([sym_batch, dead_zone]).numpy()
sym_batch = np.round(sym_batch)

hard_distr_embed = np.broadcast_to(hard_distr_embed, (sym_batch.shape[0], sym_batch.shape[1], 2*sym_batch.shape[2]))
fake_lambda = np.ones((sym_batch.shape[0], sym_batch.shape[1], 1), dtype='float32')
rate_input = np.concatenate((sym_batch, hard_distr_embed, fake_lambda), axis=-1)
rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
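# sq_rate_metric estimates the rate of each quantized symbol vector; fake_lambda merely fills
# the lambda slot the metric expects in its input, and the result is only used for the
# average-rate printout below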
#print(rates.shape)
print("average rate = ", np.mean(rates[args.num_redundancy_frames:,:]))

#sym_batch.tofile('qsyms.f32')

sym_batch = sym_batch / quant_scale
#print(sym_batch.shape, quant_state.shape)
#features = decoder.predict([sym_batch, quant_state])
features = decoder([sym_batch, quant_state])
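# all packets are decoded in one batched call; note that this reassigns 'features' to the
# per-packet reconstructed feature frames (the input features loaded above are no longer needed)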

#for i in range(offset, num_frames):
#    print(f"processing frame {i - offset}...")
#    features = decoder.predict([qsymbols[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_embed_dec[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_gru_state_dec[:, i, :]])
#    packets.append(features)
#    packet_size = 8 * int((np.sum(rates[:, i - 2 * input_length + 2 : i + 1 : 2]) + 7) / 8) + 64
#    packet_sizes.append(packet_size)


# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
#write_fec_packets(packet_file, packets, packet_sizes)


#print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
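# if a loss trace is given, simulate concealment: frames marked as lost are replaced by the
# feature frames decoded from the redundancy of the next received packet, and the resulting
# feature stream is written as <output without .fec>_fec.f32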
if args.lossfile is not None:
    loss = np.loadtxt(args.lossfile, dtype='int16')
    fec_out = np.zeros((features.shape[0]*2, features.shape[-1]), dtype='float32')
    foffset = -2
    ptr = 0
    count = 2
    for i in range(features.shape[0]):
        if (loss[i] == 0) or (i == features.shape[0]-1):
            fec_out[ptr:ptr+count,:] = features[i, foffset:, :]
            #print("filled ", count)
            foffset = -2
            ptr = ptr+count
            count = 2
        else:
            count = count + 2
            foffset = foffset - 2

    fec_out_full = np.zeros((fec_out.shape[0], nb_features), dtype=np.float32)
    fec_out_full[:, :nb_used_features] = fec_out

    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')


# create packets array like in the original version for debugging purposes
for i in range(offset, num_frames):
    packets.append(features[i-offset:i-offset+1, :, :])

if args.debug_output:
    import itertools

    #batches = [2, 4]
    batches = [4]
    #offsets = [0, 4, 20]
    offsets = [0, (args.num_redundancy_frames - 2)*2]
    # sanity check: taking 'batch' feature frames from the tail (at the given offset) of every
    # (batch//2)-th packet should reassemble a continuous feature stream for comparison
    for batch, offset in itertools.product(batches, offsets):

        stop = packets[0].shape[1] - offset
        print(batch, offset, stop)
        test_features = np.concatenate([packet[:, stop - batch : stop, :] for packet in packets[::batch//2]], axis=1)

        test_features_full = np.zeros((test_features.shape[1], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[0, :, :]

        print(f"writing debug output {packet_file[:-4] + f'_tf_batch{batch}_offset{offset}.f32'}")
        test_features_full.tofile(packet_file[:-4] + f'_tf_batch{batch}_offset{offset}.f32')
257