#!/usr/bin/python3
'''Copyright (c) 2017-2018 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
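
# dump_lpcnet.py converts a trained LPCNet Keras checkpoint (.h5) into a C
# source/header pair (nnet_data.c / nnet_data.h by default) containing the
# quantized weight arrays plus the model struct and init glue used together
# with nnet.h by the C inference code.
#
# Example invocation (the checkpoint name is illustrative):
#   python3 dump_lpcnet.py lpcnet_model.h5 --nnet-header nnet_data.h --nnet-source nnet_data.c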

import os
import io
import lpcnet
import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Layer, GRU, Dense, Conv1D, Embedding
from ulaw import ulaw2lin, lin2ulaw
from mdense import MDense
from diffembed import diff_Embed
from parameters import get_parameter
import h5py
import re
import argparse


# No CUDA devices are needed for dumping weights.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

# Flag for dumping e2e (differentiable LPC) network weights
flag_e2e = False


# Running maxima tracked while dumping; they become MAX_RNN_NEURONS,
# MAX_CONV_INPUTS and MAX_MDENSE_TMP in the generated header.
max_rnn_neurons = 1
max_conv_inputs = 1
max_mdense_tmp = 1

def printVector(f, vector, name, dtype='float', dotp=False):
    global array_list
    if dotp:
        # Regroup the weights into 4x8 blocks (4 rows x 8 columns) in the
        # order expected by the DOT_PROD kernels.
        vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))
        vector = vector.transpose((2, 0, 3, 1))
    v = np.reshape(vector, (-1))
    #print('static const float ', name, '[', len(v), '] = \n', file=f)
    if name not in array_list:
        array_list.append(name)
    f.write('#ifndef USE_WEIGHTS_FILE\n')
    f.write('#define WEIGHTS_{}_DEFINED\n'.format(name))
    f.write('#define WEIGHTS_{}_TYPE WEIGHT_TYPE_{}\n'.format(name, dtype))
    f.write('static const {} {}[{}] = {{\n   '.format(dtype, name, len(v)))
    for i in range(len(v)):
        f.write('{}'.format(v[i]))
        if i != len(v)-1:
            f.write(',')
        else:
            break
        if i % 8 == 7:
            f.write("\n   ")
        else:
            f.write(" ")
    #print(v, file=f)
    f.write('\n};\n')
    f.write('#endif\n\n')
    return
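
# Rough usage sketch (illustrative; not executed as part of this script).
# printVector() registers the array name in the global array_list, so it can
# later be listed in the generated lpcnet_arrays[] table, and writes a guarded
# C array definition:
#
#   array_list = []
#   buf = io.StringIO()
#   printVector(buf, np.zeros((16, 24), dtype='float32'), 'demo_weights')
#   # buf now holds the #ifndef USE_WEIGHTS_FILE guards plus:
#   #   static const float demo_weights[384] = { 0.0, 0.0, ... };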

def printSparseVector(f, A, name, have_diag=True):
    N = A.shape[0]
    M = A.shape[1]
    W = np.zeros((0,), dtype='int')
    W0 = np.zeros((0,))
    if have_diag:
        # Store the diagonals of the three NxN gate sub-matrices separately and
        # zero them in A so they are not duplicated in the sparse blocks.
        diag = np.concatenate([np.diag(A[:,:N]), np.diag(A[:,N:2*N]), np.diag(A[:,2*N:])])
        A[:,:N] = A[:,:N] - np.diag(np.diag(A[:,:N]))
        A[:,N:2*N] = A[:,N:2*N] - np.diag(np.diag(A[:,N:2*N]))
        A[:,2*N:] = A[:,2*N:] - np.diag(np.diag(A[:,2*N:]))
        printVector(f, diag, name + '_diag')
    AQ = np.minimum(127, np.maximum(-128, np.round(A*128))).astype('int')
    idx = np.zeros((0,), dtype='int')
    for i in range(M//8):
        pos = idx.shape[0]
        idx = np.append(idx, -1)
        nb_nonzero = 0
        for j in range(N//4):
            block = A[j*4:(j+1)*4, i*8:(i+1)*8]
            qblock = AQ[j*4:(j+1)*4, i*8:(i+1)*8]
            if np.sum(np.abs(block)) > 1e-10:
                nb_nonzero = nb_nonzero + 1
                idx = np.append(idx, j*4)
                vblock = qblock.transpose((1,0)).reshape((-1,))
                W0 = np.concatenate([W0, block.reshape((-1,))])
                W = np.concatenate([W, vblock])
        idx[pos] = nb_nonzero
    f.write('#ifdef DOT_PROD\n')
    printVector(f, W, name, dtype='qweight')
    f.write('#else /*DOT_PROD*/\n')
    printVector(f, W0, name, dtype='qweight')
    f.write('#endif /*DOT_PROD*/\n')
    #idx = np.tile(np.concatenate([np.array([N]), np.arange(N)]), 3*N//16)
    printVector(f, idx, name + '_idx', dtype='int')
    return AQ
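
# Block-sparse layout written by printSparseVector (derived from the loops
# above): the matrix is scanned in 4x8 blocks (4 rows x 8 output columns).
# For each group of 8 columns, '<name>_idx' holds the number of non-zero
# blocks followed by the starting row of each kept block, and '<name>' holds
# the kept block values (int8-quantized when DOT_PROD is defined, the
# unquantized values otherwise). With have_diag=True, '<name>_diag' holds the
# three gate diagonals removed from the matrix before sparsification.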

def dump_layer_ignore(self, f, hf):
    print("ignoring layer " + self.name + " of type " + self.__class__.__name__)
    return False
Layer.dump_layer = dump_layer_ignore
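
# Dispatch convention: each Keras layer class of interest is monkey-patched
# with a dump_layer() method below. The main loop calls layer.dump_layer(f, hf)
# on every layer of the model; returning True marks the layer as stateful, so
# its name is added to layer_list and a <NAME>_STATE_SIZE buffer is emitted in
# the NNetState struct of the generated header.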

def dump_sparse_gru(self, f, hf):
    global max_rnn_neurons
    name = 'sparse_' + self.name
    print("printing layer " + name + " of type sparse " + self.__class__.__name__)
    weights = self.get_weights()
    qweights = printSparseVector(f, weights[1], name + '_recurrent_weights')
    printVector(f, weights[-1], name + '_bias')
    subias = weights[-1].copy()
    # subias: a copy of the bias whose recurrent half (row 1) is pre-compensated
    # by the column sums of the quantized recurrent weights (scaled by 1/128);
    # see the "regular and SU biases" note near the gru_b dump below.
    subias[1,:] = subias[1,:] - np.sum(qweights*(1./128),axis=0)
    printVector(f, subias, name + '_subias')
    if hasattr(self, 'activation'):
        activation = self.activation.__name__.upper()
    else:
        activation = 'TANH'
    if hasattr(self, 'reset_after') and not self.reset_after:
        reset_after = 0
    else:
        reset_after = 1
    neurons = weights[0].shape[1]//3
    max_rnn_neurons = max(max_rnn_neurons, neurons)
    hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights[0].shape[1]//3))
    hf.write('#define {}_STATE_SIZE {}\n'.format(name.upper(), weights[0].shape[1]//3))
    model_struct.write('  SparseGRULayer {};\n'.format(name))
    model_init.write('  if (sparse_gru_init(&model->{}, arrays, "{}_bias", "{}_subias", "{}_recurrent_weights_diag", "{}_recurrent_weights", "{}_recurrent_weights_idx",  {}, ACTIVATION_{}, {})) return 1;\n'
            .format(name, name, name, name, name, name, weights[0].shape[1]//3, activation, reset_after))
    return True

def dump_grub(self, f, hf, gru_a_size):
    global max_rnn_neurons
    name = self.name
    print("printing layer " + name + " of type " + self.__class__.__name__)
    weights = self.get_weights()
    # Only the first gru_a_size input rows are dumped here; the remaining
    # (feature) rows are exported separately as gru_b_dense_feature.
    qweight = printSparseVector(f, weights[0][:gru_a_size, :], name + '_weights', have_diag=False)

    f.write('#ifdef DOT_PROD\n')
    qweight2 = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
    printVector(f, qweight2, name + '_recurrent_weights', dotp=True, dtype='qweight')
    f.write('#else /*DOT_PROD*/\n')
    printVector(f, weights[1], name + '_recurrent_weights')
    f.write('#endif /*DOT_PROD*/\n')

    printVector(f, weights[-1], name + '_bias')
    subias = weights[-1].copy()
    subias[0,:] = subias[0,:] - np.sum(qweight*(1./128.),axis=0)
    subias[1,:] = subias[1,:] - np.sum(qweight2*(1./128.),axis=0)
    printVector(f, subias, name + '_subias')
    if hasattr(self, 'activation'):
        activation = self.activation.__name__.upper()
    else:
        activation = 'TANH'
    if hasattr(self, 'reset_after') and not self.reset_after:
        reset_after = 0
    else:
        reset_after = 1
    neurons = weights[0].shape[1]//3
    max_rnn_neurons = max(max_rnn_neurons, neurons)
    model_struct.write('  GRULayer {};\n'.format(name))
    model_init.write('  if (gru_init(&model->{}, arrays, "{}_bias", "{}_subias", "{}_weights", "{}_weights_idx", "{}_recurrent_weights", {}, {}, ACTIVATION_{}, {})) return 1;\n'
            .format(name, name, name, name, name, name, gru_a_size, weights[0].shape[1]//3, activation, reset_after))
    return True

def dump_gru_layer_dummy(self, f, hf):
    # The GRU layers are dumped explicitly via dump_sparse_gru()/dump_grub();
    # the generic loop only writes their size macros here.
    name = self.name
    weights = self.get_weights()
    hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights[0].shape[1]//3))
    hf.write('#define {}_STATE_SIZE {}\n'.format(name.upper(), weights[0].shape[1]//3))
    return True

GRU.dump_layer = dump_gru_layer_dummy

def dump_dense_layer_impl(name, weights, bias, activation, f, hf):
    printVector(f, weights, name + '_weights')
    printVector(f, bias, name + '_bias')
    hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights.shape[1]))
    model_struct.write('  DenseLayer {};\n'.format(name))
    model_init.write('  if (dense_init(&model->{}, arrays, "{}_bias", "{}_weights", {}, {}, ACTIVATION_{})) return 1;\n'
            .format(name, name, name, weights.shape[0], weights.shape[1], activation))

def dump_dense_layer(self, f, hf):
    name = self.name
    print("printing layer " + name + " of type " + self.__class__.__name__)
    weights = self.get_weights()
    activation = self.activation.__name__.upper()
    dump_dense_layer_impl(name, weights[0], weights[1], activation, f, hf)
    return False

Dense.dump_layer = dump_dense_layer

def dump_mdense_layer(self, f, hf):
    global max_mdense_tmp
    name = self.name
    print("printing layer " + name + " of type " + self.__class__.__name__)
    weights = self.get_weights()
    printVector(f, np.transpose(weights[0], (0, 2, 1)), name + '_weights')
    printVector(f, np.transpose(weights[1], (1, 0)), name + '_bias')
    printVector(f, np.transpose(weights[2], (1, 0)), name + '_factor')
    activation = self.activation.__name__.upper()
    max_mdense_tmp = max(max_mdense_tmp, weights[0].shape[0]*weights[0].shape[2])
    hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights[0].shape[0]))
    model_struct.write('  MDenseLayer {};\n'.format(name))
    model_init.write('  if (mdense_init(&model->{}, arrays, "{}_bias",  "{}_weights",  "{}_factor",  {}, {}, {}, ACTIVATION_{})) return 1;\n'
            .format(name, name, name, name, weights[0].shape[1], weights[0].shape[0], weights[0].shape[2], activation))
    return False
MDense.dump_layer = dump_mdense_layer

def dump_conv1d_layer(self, f, hf):
    global max_conv_inputs
    name = self.name
    print("printing layer " + name + " of type " + self.__class__.__name__)
    weights = self.get_weights()
    printVector(f, weights[0], name + '_weights')
    printVector(f, weights[-1], name + '_bias')
    activation = self.activation.__name__.upper()
    max_conv_inputs = max(max_conv_inputs, weights[0].shape[1]*weights[0].shape[0])
    hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights[0].shape[2]))
    hf.write('#define {}_STATE_SIZE ({}*{})\n'.format(name.upper(), weights[0].shape[1], (weights[0].shape[0]-1)))
    hf.write('#define {}_DELAY {}\n'.format(name.upper(), (weights[0].shape[0]-1)//2))
    model_struct.write('  Conv1DLayer {};\n'.format(name))
    model_init.write('  if (conv1d_init(&model->{}, arrays, "{}_bias", "{}_weights", {}, {}, {}, ACTIVATION_{})) return 1;\n'
            .format(name, name, name, weights[0].shape[1], weights[0].shape[0], weights[0].shape[2], activation))
    return True
Conv1D.dump_layer = dump_conv1d_layer


def dump_embedding_layer_impl(name, weights, f, hf):
    printVector(f, weights, name + '_weights')
    hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights.shape[1]))
    model_struct.write('  EmbeddingLayer {};\n'.format(name))
    model_init.write('  if (embedding_init(&model->{}, arrays, "{}_weights", {}, {})) return 1;\n'
            .format(name, name, weights.shape[0], weights.shape[1]))

def dump_embedding_layer(self, f, hf):
    name = self.name
    print("printing layer " + name + " of type " + self.__class__.__name__)
    weights = self.get_weights()[0]
    dump_embedding_layer_impl(name, weights, f, hf)
    return False
Embedding.dump_layer = dump_embedding_layer
diff_Embed.dump_layer = dump_embedding_layer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, help='model weight h5 file')
    parser.add_argument('--nnet-header', type=str, help='name of c header file for dumped model', default='nnet_data.h')
    parser.add_argument('--nnet-source', type=str, help='name of c source file for dumped model', default='nnet_data.c')
    parser.add_argument('--lpc-gamma', type=float, help='LPC weighting factor. If not specified, it is read from the model file (default 1)', default=None)
    parser.add_argument('--lookahead', type=float, help='features lookahead. If not specified, it is read from the model file (default 2)', default=None)

    args = parser.parse_args()

    filename = args.model_file
    with h5py.File(filename, "r") as f:
        units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape)
        units2 = min(f['model_weights']['gru_b']['gru_b']['recurrent_kernel:0'].shape)
        cond_size = min(f['model_weights']['feature_dense1']['feature_dense1']['kernel:0'].shape)
        e2e = 'rc2lpc' in f['model_weights']

    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2, flag_e2e=e2e, cond_size=cond_size)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    #model.summary()

    model.load_weights(filename, by_name=True)

    cfile = args.nnet_source
    hfile = args.nnet_header

    f = open(cfile, 'w')
    hf = open(hfile, 'w')
    model_struct = io.StringIO()
    model_init = io.StringIO()
    model_struct.write('typedef struct {\n')
    model_init.write('#ifndef DUMP_BINARY_WEIGHTS\n')
    model_init.write('int init_lpcnet_model(LPCNetModel *model, const WeightArray *arrays) {\n')
    array_list = []

    f.write('/*This file is automatically generated from a Keras model*/\n')
    f.write('/*based on model {}*/\n\n'.format(args.model_file))
    f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\n#include "{}"\n\n'.format(hfile))

    hf.write('/*This file is automatically generated from a Keras model*/\n\n')
    hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "nnet.h"\n\n')

    if e2e:
        hf.write('/* This is an end-to-end model */\n')
        hf.write('#define END2END\n\n')
    else:
        hf.write('/* This is *not* an end-to-end model */\n')
        hf.write('/* #define END2END */\n\n')

    # LPC weighting factor
    if args.lpc_gamma is None:
        lpc_gamma = get_parameter(model, 'lpc_gamma', 1)
    else:
        lpc_gamma = args.lpc_gamma

    hf.write('/* LPC weighting factor */\n')
    hf.write('#define LPC_GAMMA ' + str(lpc_gamma) +'f\n\n')

    # look-ahead
    if args.lookahead is None:
        lookahead = get_parameter(model, 'lookahead', 2)
    else:
        lookahead = args.lookahead

    hf.write('/* Features look-ahead */\n')
    hf.write('#define FEATURES_DELAY ' + str(lookahead) +'\n\n')

    embed_size = lpcnet.embed_size

    E = model.get_layer('embed_sig').get_weights()[0]
    W = model.get_layer('gru_a').get_weights()[0][:embed_size,:]
    dump_embedding_layer_impl('gru_a_embed_sig', np.dot(E, W), f, hf)
    W = model.get_layer('gru_a').get_weights()[0][embed_size:2*embed_size,:]
    dump_embedding_layer_impl('gru_a_embed_pred', np.dot(E, W), f, hf)
    W = model.get_layer('gru_a').get_weights()[0][2*embed_size:3*embed_size,:]
    dump_embedding_layer_impl('gru_a_embed_exc', np.dot(E, W), f, hf)
    W = model.get_layer('gru_a').get_weights()[0][3*embed_size:,:]
    #FIXME: dump only half the biases
    b = model.get_layer('gru_a').get_weights()[2]
    dump_dense_layer_impl('gru_a_dense_feature', W, b[:len(b)//2], 'LINEAR', f, hf)
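
    # Note on the three gru_a_embed_* dumps above: E is the signal embedding
    # table and W is a slice of gru_a's input kernel, so np.dot(E, W) bakes
    # the embedding followed by the input matrix product into a single lookup
    # table (one row per embedded input level), since
    #   np.dot(E, W)[k] == np.dot(E[k], W)   for any sample index k.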

    W = model.get_layer('gru_b').get_weights()[0][model.rnn_units1:,:]
    b = model.get_layer('gru_b').get_weights()[2]
    # Set biases to zero because they'll be included in the GRU input part
    # (we need regular and SU biases)
    dump_dense_layer_impl('gru_b_dense_feature', W, 0*b[:len(b)//2], 'LINEAR', f, hf)
    dump_grub(model.get_layer('gru_b'), f, hf, model.rnn_units1)

    layer_list = []
    for layer in model.layers:
        if layer.dump_layer(f, hf):
            layer_list.append(layer.name)

    dump_sparse_gru(model.get_layer('gru_a'), f, hf)

    f.write('#ifndef USE_WEIGHTS_FILE\n')
    f.write('const WeightArray lpcnet_arrays[] = {\n')
    for name in array_list:
        f.write('#ifdef WEIGHTS_{}_DEFINED\n'.format(name))
        f.write('  {{"{}", WEIGHTS_{}_TYPE, sizeof({}), {}}},\n'.format(name, name, name, name))
        f.write('#endif\n')
    f.write('  {NULL, 0, 0, NULL}\n};\n')
    f.write('#endif\n')

    model_init.write('  return 0;\n}\n')
    model_init.write('#endif\n')
    f.write(model_init.getvalue())

    hf.write('#define MAX_RNN_NEURONS {}\n\n'.format(max_rnn_neurons))
    hf.write('#define MAX_CONV_INPUTS {}\n\n'.format(max_conv_inputs))
    hf.write('#define MAX_MDENSE_TMP {}\n\n'.format(max_mdense_tmp))


    hf.write('typedef struct {\n')
    for name in layer_list:
        hf.write('  float {}_state[{}_STATE_SIZE];\n'.format(name, name.upper()))
    hf.write('} NNetState;\n\n')

    model_struct.write('} LPCNetModel;\n\n')
    hf.write(model_struct.getvalue())
    hf.write('int init_lpcnet_model(LPCNetModel *model, const WeightArray *arrays);\n\n')
    hf.write('\n\n#endif\n')

    f.close()
    hf.close()
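
    # Rough shape of the generated header (model-dependent values omitted):
    #   #define LPC_GAMMA <gamma>f
    #   #define FEATURES_DELAY <lookahead>
    #   #define <LAYER>_OUT_SIZE / _STATE_SIZE / _DELAY ...   (per layer)
    #   #define MAX_RNN_NEURONS / MAX_CONV_INPUTS / MAX_MDENSE_TMP ...
    #   typedef struct { float <layer>_state[...]; ... } NNetState;
    #   typedef struct { <Type>Layer <layer>; ... } LPCNetModel;
    #   int init_lpcnet_model(LPCNetModel *model, const WeightArray *arrays);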