xref: /aosp_15_r20/external/libopus/dnn/training_tf2/fec_encoder.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li"""
2*a58d3d2aSXin Li/* Copyright (c) 2022 Amazon
3*a58d3d2aSXin Li   Written by Jan Buethe and Jean-Marc Valin */
4*a58d3d2aSXin Li/*
5*a58d3d2aSXin Li   Redistribution and use in source and binary forms, with or without
6*a58d3d2aSXin Li   modification, are permitted provided that the following conditions
7*a58d3d2aSXin Li   are met:
8*a58d3d2aSXin Li
9*a58d3d2aSXin Li   - Redistributions of source code must retain the above copyright
10*a58d3d2aSXin Li   notice, this list of conditions and the following disclaimer.
11*a58d3d2aSXin Li
12*a58d3d2aSXin Li   - Redistributions in binary form must reproduce the above copyright
13*a58d3d2aSXin Li   notice, this list of conditions and the following disclaimer in the
14*a58d3d2aSXin Li   documentation and/or other materials provided with the distribution.
15*a58d3d2aSXin Li
16*a58d3d2aSXin Li   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17*a58d3d2aSXin Li   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18*a58d3d2aSXin Li   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19*a58d3d2aSXin Li   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20*a58d3d2aSXin Li   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21*a58d3d2aSXin Li   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22*a58d3d2aSXin Li   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23*a58d3d2aSXin Li   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24*a58d3d2aSXin Li   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25*a58d3d2aSXin Li   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26*a58d3d2aSXin Li   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*a58d3d2aSXin Li*/
28*a58d3d2aSXin Li"""
29*a58d3d2aSXin Liimport os
30*a58d3d2aSXin Liimport subprocess
31*a58d3d2aSXin Liimport argparse
32*a58d3d2aSXin Li
33*a58d3d2aSXin Li
34*a58d3d2aSXin Liimport numpy as np
35*a58d3d2aSXin Lifrom scipy.io import wavfile
36*a58d3d2aSXin Liimport tensorflow as tf
37*a58d3d2aSXin Li
38*a58d3d2aSXin Lifrom rdovae import new_rdovae_model, pvq_quantize, apply_dead_zone, sq_rate_metric
39*a58d3d2aSXin Lifrom fec_packets import write_fec_packets, read_fec_packets
40*a58d3d2aSXin Li
41*a58d3d2aSXin Li
# Set to True to bypass the command line and run with the hard-coded settings
# below.
debug = False

if debug:
    # argparse.Namespace offers the same attribute-style access as the result
    # of parse_args(). Every option that is read further down in the script
    # has to be present here, including quant_levels, lossfile and
    # debug_output (values mirror the command line defaults).
    args = argparse.Namespace(
        input='item1.wav',
        weights='testout/rdovae_alignment_fix_1024_120.h5',
        enc_lambda=0.0007,
        output="test_0007.fec",
        cond_size=1024,
        quant_levels=40,
        num_redundancy_frames=64,
        extra_delay=0,
        dump_data='./dump_data',
        lossfile=None,
        debug_output=False,
    )
    # NOTE(review): an empty CUDA_VISIBLE_DEVICES presumably keeps the debug
    # run off the GPU -- confirm this still takes effect after the tensorflow
    # import above.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
else:
    parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames')

    # positional arguments
    parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
    parser.add_argument('weights', metavar='<weights>', help='trained model file (.h5)')
#    parser.add_argument('enc_lambda', metavar='<lambda>', type=float, help='lambda for controlling encoder rate')
    parser.add_argument('output', type=str, help='output file (will be extended with .fec)')

    # optional arguments
    parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
    parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
    parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 40)", default=40)
    parser.add_argument('--num-redundancy-frames', default=64, type=int, help='number of redundancy frames (20ms) per packet (default 64)')
    parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
    parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')

    parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

    args = parser.parse_args()
75*a58d3d2aSXin Li
# Build the RDOVAE model together with its encoder, decoder and quantizer
# embedding sub-models, then restore the trained weights.
model, encoder, decoder, qembedding = new_rdovae_model(
    nb_used_features=20,
    nb_bits=80,
    batch_size=1,
    nb_quant=args.quant_levels,
    cond_size=args.cond_size,
)
model.load_weights(args.weights)

lpc_order = 16

## prepare input signal
# LPCNet subframes are 10ms; a SILK frame is 20ms, i.e. two subframes
subframe_size = 160
frame_size = subframe_size * 2

# dump data has a (feature) delay of 10ms
dump_data_delay = 160

# 91 samples delay to align with SILK decoded frames
silk_delay = 91

# zeros are prepended so that there is enough history to produce the first
# packet
zero_history = frame_size * (args.num_redundancy_frames - 1)

# number of zero samples that end up in front of the signal: everything that
# delays the signal, minus the delay dump_data introduces on its own
total_delay = (silk_delay + zero_history + args.extra_delay) - dump_data_delay
96*a58d3d2aSXin Li
# load signal
if args.input.endswith(('.raw', '.pcm', '.sw')):
    # headerless input is interpreted as int16 samples
    signal = np.fromfile(args.input, dtype='int16')

elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)

    # wavfile.read returns a 2-D array for multi-channel files, which cannot
    # be concatenated with the 1-D zero padding below
    if signal.ndim != 1:
        raise ValueError(f'expected a mono signal but {args.input} has {signal.shape[1]} channels')

    # the padded signal is written as raw samples for dump_data; any dtype
    # other than int16 would silently produce a file in the wrong sample format
    if signal.dtype != np.int16:
        raise ValueError(f'expected int16 samples but {args.input} contains {signal.dtype} data')

    # the frame sizes and delays of this script are given in samples at 16kHz
    if fs != 16000:
        print(f"warning: {args.input} has a sample rate of {fs} Hz, but this script assumes 16000 Hz")
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# total_delay is the number of zeros prepended to the signal; it gets negative
# for very small values of --num-redundancy-frames / --extra-delay
if total_delay < 0:
    raise ValueError(f'total delay must not be negative (got {total_delay} samples); increase --num-redundancy-frames or --extra-delay')

# fill up last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

# write the padded signal next to the input file as raw int16 samples
padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)
115*a58d3d2aSXin Li
# call dump_data on the padded signal to create features

feature_file = os.path.splitext(args.input)[0] + '_features.f32'

# The arguments are passed as a list and without a shell, so that file names
# containing spaces or shell metacharacters reach dump_data unchanged.
dump_data_argv = [args.dump_data, '-test', padded_signal_file, feature_file]
command = ' '.join(dump_data_argv)  # only used for the error messages below

try:
    r = subprocess.run(dump_data_argv)
except OSError as err:
    # e.g. the executable does not exist or is not executable
    raise RuntimeError(f"command '{command}' could not be started: {err}") from err

if r.returncode != 0:
    raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")
123*a58d3d2aSXin Li
# load features
# Each feature vector in the file has nb_features entries; only the first
# nb_used_features of them are kept for the model.
# NOTE(review): the remaining lpc_order entries are presumably the LPC
# coefficients written by dump_data -- confirm.
nb_used_features = model.nb_used_features
nb_features = nb_used_features + lpc_order

features = np.fromfile(feature_file, dtype='float32')

# two subframes form one frame, so a trailing odd subframe is dropped
num_frames = (len(features) // nb_features) // 2
num_subframes = 2 * num_frames

# shape (1, num_subframes, nb_used_features)
features = features.reshape((1, -1, nb_features))[:, :num_subframes, :nb_used_features]
137*a58d3d2aSXin Li
# variable quantizer depending on the delay: the quantizer index starts at q1
# and moves linearly towards q0 over the positions of a packet
q0, q1 = 3, 15
quant_id = q1 + (q0 - q1) * np.arange(args.num_redundancy_frames // 2) / args.num_redundancy_frames
quant_id = np.round(quant_id).astype('int16')
#print(quant_id)

# one embedding row per packet position
quant_embed = qembedding(quant_id)

# run encoder
print("running fec encoder...")
symbols, gru_state_dec = encoder.predict(features)

# quantization parameters taken from the embedding: the first nsymbols columns
# give the scale and the next nsymbols columns the dead zone (both through a
# softplus)
nsymbols = 80
scale_columns = quant_embed[:, :nsymbols]
dead_zone_columns = quant_embed[:, nsymbols : 2 * nsymbols]
quant_scale = tf.math.softplus(scale_columns).numpy()
dead_zone = tf.math.softplus(dead_zone_columns).numpy()
#symbols = apply_dead_zone([symbols, dead_zone]).numpy()
#qsymbols = np.round(symbols)

# quantize the decoder initial state
quant_gru_state_dec = pvq_quantize(gru_state_dec, 82)

# rate estimate: the columns from 4 * nsymbols onwards go through a sigmoid
# and are handed to the rate metric further down
distr_columns = quant_embed[:, 4 * nsymbols : ]
hard_distr_embed = tf.math.sigmoid(distr_columns).numpy()
#rate_input = np.concatenate((qsymbols, hard_distr_embed, enc_lambda), axis=-1)
#rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
162*a58d3d2aSXin Li
# run decoder
# number of symbol vectors per packet and index of the first frame for which
# a complete packet can be assembled
input_length = args.num_redundancy_frames // 2
offset = args.num_redundancy_frames - 1

packets = []
packet_sizes = []

# pack symbols for batch processing: one row per packet, each row holding
# every second symbol vector up to and including index i
sym_batch = np.zeros((num_frames - offset, input_length, nsymbols), dtype='float32')
quant_state = quant_gru_state_dec[0, offset:num_frames, :]
for row, i in enumerate(range(offset, num_frames)):
    first = i - 2 * input_length + 2
    sym_batch[row, :, :] = symbols[0, first : i + 1 : 2, :]

# quantize symbols: scale, apply the dead zone, round to integers
sym_batch = np.multiply(sym_batch, quant_scale)
sym_batch = apply_dead_zone([sym_batch, dead_zone]).numpy()
sym_batch = np.round(sym_batch)

# rate estimate on the quantized symbols; the rate metric expects symbols,
# distribution parameters and a lambda column concatenated on the last axis
num_packets, symbols_per_packet, symbol_dim = sym_batch.shape
hard_distr_embed = np.broadcast_to(hard_distr_embed, (num_packets, symbols_per_packet, 2 * symbol_dim))
fake_lambda = np.ones((num_packets, symbols_per_packet, 1), dtype='float32')
rate_input = np.concatenate((sym_batch, hard_distr_embed, fake_lambda), axis=-1)
rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
#print(rates.shape)
print("average rate = ", np.mean(rates[args.num_redundancy_frames:,:]))

#sym_batch.tofile('qsyms.f32')

# undo the scaling and decode all packets in a single batch
sym_batch = np.divide(sym_batch, quant_scale)
#print(sym_batch.shape, quant_state.shape)
#features = decoder.predict([sym_batch, quant_state])
features = decoder([sym_batch, quant_state])
194*a58d3d2aSXin Li
195*a58d3d2aSXin Li#for i in range(offset, num_frames):
196*a58d3d2aSXin Li#    print(f"processing frame {i - offset}...")
197*a58d3d2aSXin Li#    features = decoder.predict([qsymbols[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_embed_dec[:, i - 2 * input_length + 2 : i + 1 : 2, :], quant_gru_state_dec[:, i, :]])
198*a58d3d2aSXin Li#    packets.append(features)
199*a58d3d2aSXin Li#    packet_size = 8 * int((np.sum(rates[:, i - 2 * input_length + 2 : i + 1 : 2]) + 7) / 8) + 64
200*a58d3d2aSXin Li#    packet_sizes.append(packet_size)
201*a58d3d2aSXin Li
202*a58d3d2aSXin Li
# write packets
packet_file = args.output if args.output.endswith('.fec') else args.output + '.fec'
#write_fec_packets(packet_file, packets, packet_sizes)


#print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")

# Simulate packet loss: walk through the loss trace and, whenever a packet is
# received (or at the very last packet), copy the newest decoded feature
# vectors of that packet so that they cover the packet itself and the run of
# lost packets directly before it.
if args.lossfile is not None:
    # atleast_1d: loadtxt returns a 0-d array for a file with a single entry
    loss = np.atleast_1d(np.loadtxt(args.lossfile, dtype='int16'))

    num_packets = features.shape[0]
    window = features.shape[1]  # number of feature vectors decoded per packet

    if len(loss) < num_packets:
        raise ValueError(f'loss trace {args.lossfile} has {len(loss)} entries but {num_packets} are needed')

    # two feature vectors per packet
    fec_out = np.zeros((num_packets * 2, features.shape[-1]), dtype='float32')
    ptr = 0    # next position to fill in fec_out
    count = 2  # number of feature vectors pending since the last received packet
    for i in range(num_packets):
        if loss[i] == 0 or i == num_packets - 1:
            # A packet cannot provide more than `window` feature vectors. For
            # longer loss bursts only the newest part of the gap is filled and
            # the older part keeps its zeros.
            avail = min(count, window)
            fec_out[ptr + count - avail : ptr + count, :] = features[i, window - avail:, :]
            #print("filled ", count)
            ptr = ptr + count
            count = 2
        else:
            count = count + 2

    # pad the feature vectors with zeros to the full nb_features entries
    fec_out_full = np.zeros((fec_out.shape[0], nb_features), dtype=np.float32)
    fec_out_full[:, :nb_used_features] = fec_out

    # packet_file always ends with '.fec', which is replaced here
    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')
231*a58d3d2aSXin Li
232*a58d3d2aSXin Li
# create packets array like in the original version for debugging purposes:
# one entry per packet, each keeping a leading axis of length one
for i in range(offset, num_frames):
    packets.append(features[i-offset:i-offset+1, :, :])

if args.debug_output:
    import itertools

    #batches = [2, 4]
    batches = [4]
    #offsets = [0, 4, 20]
    offsets = [0, (args.num_redundancy_frames - 2)*2]
    # sanity checks
    # 1. concatenate features at offset 0
    # The loop variable is deliberately not called `offset`, so that the
    # module-level `offset` (index of the first complete packet) stays intact.
    for batch, feature_offset in itertools.product(batches, offsets):

        # take `batch` feature vectors ending `feature_offset` vectors before
        # the end of every (batch // 2)-th packet and put them back to back
        stop = packets[0].shape[1] - feature_offset
        print(batch, feature_offset, stop)
        test_features = np.concatenate([packet[:,stop - batch: stop, :] for packet in packets[::batch//2]], axis=1)

        # pad the feature vectors with zeros to the full nb_features entries
        test_features_full = np.zeros((test_features.shape[1], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[0, :, :]

        debug_file = packet_file[:-4] + f'_tf_batch{batch}_offset{feature_offset}.f32'
        print(f"writing debug output {debug_file}")
        test_features_full.tofile(debug_file)
257