xref: /aosp_15_r20/external/libopus/dnn/training_tf2/test_lpcnet.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li#!/usr/bin/python3
2*a58d3d2aSXin Li'''Copyright (c) 2018 Mozilla
3*a58d3d2aSXin Li
4*a58d3d2aSXin Li   Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li   modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li   are met:
7*a58d3d2aSXin Li
8*a58d3d2aSXin Li   - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li   notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li
11*a58d3d2aSXin Li   - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li   notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li   documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li
15*a58d3d2aSXin Li   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19*a58d3d2aSXin Li   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li'''
27*a58d3d2aSXin Liimport argparse
28*a58d3d2aSXin Liimport sys
29*a58d3d2aSXin Li
30*a58d3d2aSXin Liimport h5py
31*a58d3d2aSXin Liimport numpy as np
32*a58d3d2aSXin Li
33*a58d3d2aSXin Liimport lpcnet
34*a58d3d2aSXin Lifrom ulaw import ulaw2lin, lin2ulaw
35*a58d3d2aSXin Li
36*a58d3d2aSXin Li
# Number of float32 values per frame in the input feature file.
NB_FEATURES = 36
# LPC order. In the non-e2e path the last LPC_ORDER values of each feature
# frame are (after weighting) the linear-prediction coefficients.
LPC_ORDER = 16
# Coefficient of the one-pole recursive filter applied to the synthesized
# signal before output (presumably de-emphasis undoing the feature
# extractor's pre-emphasis -- TODO confirm).
DEEMPHASIS_COEF = 0.85


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_file``, ``feature_file``, ``out_file``
        and ``lpc_gamma`` attributes.
    """
    parser = argparse.ArgumentParser()
    # argparse uses a positional argument's name verbatim as its attribute
    # name (dashes are NOT converted to underscores), so the dest must use an
    # underscore; the metavar keeps the original dashed spelling in --help.
    parser.add_argument('model_file', metavar='model-file', type=str, help='model weight h5 file')
    parser.add_argument('feature_file', metavar='feature-file', type=str,
                        help='input feature file (raw float32, %d values per frame)' % NB_FEATURES)
    parser.add_argument('out_file', metavar='out-file', type=str, help='output file (raw int16 samples)')
    parser.add_argument('--lpc-gamma', type=float, help='LPC weighting factor. WARNING: giving an inconsistent value here will severely degrade performance', default=1)
    return parser.parse_args(argv)


def read_features(path, nb_features=NB_FEATURES):
    """Load a raw float32 feature file as an array of shape (frames, nb_features).

    Trailing values that do not fill a whole frame are dropped.

    Raises:
        ValueError: if the file does not contain at least one whole frame.
    """
    features = np.fromfile(path, dtype='float32')
    nb_whole = len(features) // nb_features
    if nb_whole == 0:
        raise ValueError('%s contains no complete %d-value feature frame' % (path, nb_features))
    return features[:nb_whole * nb_features].reshape(nb_whole, nb_features)


def lpc_predict(a, pcm, pos, order=LPC_ORDER):
    """Return the linear prediction of pcm[pos] from the ``order`` previous samples.

    ``a[k]`` weights ``pcm[pos - 1 - k]`` (most recent sample first) and the
    prediction is ``-sum(a * history)``. Requires ``pos >= order``.
    """
    history = pcm[pos - order:pos][::-1]
    return -np.dot(a, history)


def sample_excitation(p, voicing):
    """Sample one excitation index from the decoder's output distribution ``p``.

    ``p`` is indexed as ``p[0, 0, :]`` for sampling (rank-3 output of the
    decoder with a single batch/time step, as used by the caller).
    ``voicing`` sharpens the distribution: the exponent applied is
    ``max(0, 1.5 * voicing - 0.5)``.
    """
    # Lower the temperature for voiced frames to reduce noisiness
    p = p * np.power(p, np.maximum(0, 1.5 * voicing - .5))
    p = p / (1e-18 + np.sum(p))
    # Cut off the tail of the remaining distribution
    p = np.maximum(p - 0.002, 0).astype('float64')
    p = p / (1e-8 + np.sum(p))
    return np.argmax(np.random.multinomial(1, p[0, 0, :], 1))


def to_int16(x):
    """Round ``x`` to a one-element int16 array, saturating at the int16 range.

    Clipping first avoids the wrap-around / undefined float->int16 cast that a
    plain conversion produces for out-of-range values.
    """
    return np.array([np.round(np.clip(x, -32768, 32767))], dtype='int16')


def main(argv=None):
    """Synthesize int16 samples from a feature file with a trained LPCNet model.

    The model's layer sizes (and whether it is an end-to-end model with its
    own LPC estimation, detected by the presence of an ``rc2lpc`` layer) are
    read from the h5 weight file. The whole feature file is processed as a
    single chunk; samples are written to the output file as they are produced.
    """
    args = parse_args(argv)

    with h5py.File(args.model_file, "r") as weights_file:
        layers = weights_file['model_weights']
        units = min(layers['gru_a']['gru_a']['recurrent_kernel:0'].shape)
        units2 = min(layers['gru_b']['gru_b']['recurrent_kernel:0'].shape)
        cond_size = min(layers['feature_dense1']['feature_dense1']['kernel:0'].shape)
        e2e = 'rc2lpc' in layers

    model, enc, dec = lpcnet.new_lpcnet_model(training=False, rnn_units1=units, rnn_units2=units2,
                                              flag_e2e=e2e, cond_size=cond_size, batch_size=1)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    model.load_weights(args.model_file)

    frame_size = model.frame_size
    nb_used_features = model.nb_used_features

    # Whole file as one chunk: shape (1, nb_frames, NB_FEATURES).
    features = read_features(args.feature_file, NB_FEATURES)[np.newaxis, :, :]
    nb_frames = features.shape[1]
    # Feature 18 is mapped to an integer period via 50*x + 100
    # (presumably the pitch parameter -- TODO confirm).
    periods = (.1 + 50 * features[:, :, 18:19] + 100).astype('int16')

    enc_inputs = [features[:, :, :nb_used_features], periods]
    if e2e:
        cfeat, lpcs = enc.predict(enc_inputs)
    else:
        cfeat = enc.predict(enc_inputs)

    lpc_weights = args.lpc_gamma ** np.arange(1, LPC_ORDER + 1)

    pcm = np.zeros((nb_frames * frame_size, ))
    fexc = np.zeros((1, 1, 3), dtype='int16') + 128
    state1 = np.zeros((1, model.rnn_units1), dtype='float32')
    state2 = np.zeros((1, model.rnn_units2), dtype='float32')
    mem = 0

    with open(args.out_file, 'wb') as fout:
        # The first LPC_ORDER+1 samples of the first frame have no full
        # history; they stay zero and are not written to the output.
        skip = LPC_ORDER + 1
        for fr in range(nb_frames):
            if e2e:
                a = lpcs[0, fr]
            else:
                a = features[0, fr, NB_FEATURES - LPC_ORDER:] * lpc_weights
            for i in range(skip, frame_size):
                pos = fr * frame_size + i
                pred = lpc_predict(a, pcm, pos, LPC_ORDER)
                fexc[0, 0, 1] = lin2ulaw(pred)

                p, state1, state2 = dec.predict([fexc, cfeat[:, fr:fr + 1, :], state1, state2])
                # Feature 19 drives the voicing-dependent temperature.
                fexc[0, 0, 2] = sample_excitation(p, features[0, fr, 19])
                pcm[pos] = pred + ulaw2lin(fexc[0, 0, 2])
                fexc[0, 0, 0] = lin2ulaw(pcm[pos])
                mem = DEEMPHASIS_COEF * mem + pcm[pos]
                to_int16(mem).tofile(fout)
            skip = 0


if __name__ == '__main__':
    main()