xref: /aosp_15_r20/external/libopus/scripts/dump_rnn.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li#!/usr/bin/python
2*a58d3d2aSXin Li
3*a58d3d2aSXin Lifrom __future__ import print_function
4*a58d3d2aSXin Li
5*a58d3d2aSXin Lifrom keras.models import Sequential
6*a58d3d2aSXin Lifrom keras.layers import Dense
7*a58d3d2aSXin Lifrom keras.layers import LSTM
8*a58d3d2aSXin Lifrom keras.layers import GRU
9*a58d3d2aSXin Lifrom keras.models import load_model
10*a58d3d2aSXin Lifrom keras import backend as K
11*a58d3d2aSXin Li
12*a58d3d2aSXin Liimport numpy as np
13*a58d3d2aSXin Li
def printVector(f, vector, name):
    """Write ``vector`` to ``f`` as a C ``static const opus_int16`` array.

    The input is flattened, every element is multiplied by 8192 and rounded
    to an integer, and the result is saturated to the signed 16-bit range so
    that each emitted literal fits the declared ``opus_int16`` element type.
    Values are written eight per line, comma separated, followed by the
    closing ``};`` and a blank line.

    Args:
        f: Writable text file-like object; only ``f.write`` is called.
        vector: Array-like of numbers, any shape accepted by ``np.reshape``.
        name: Identifier used for the generated C array.
    """
    scale = 8192
    int16_min, int16_max = -32768, 32767
    per_line = 8

    flat = np.reshape(vector, (-1))
    # Saturate instead of emitting a literal that cannot be represented in
    # opus_int16; values already in range are written unchanged.
    quantized = [
        max(int16_min, min(int16_max, int(round(scale * x)))) for x in flat
    ]
    # One string per output line; only the separators differ between
    # "next value on this line" and "first value on the next line".
    rows = [
        ', '.join(str(q) for q in quantized[start:start + per_line])
        for start in range(0, len(quantized), per_line)
    ]

    f.write('static const opus_int16 {}[{}] = {{\n   '.format(name, len(quantized)))
    f.write(',\n   '.join(rows))
    f.write('\n};\n\n')
31*a58d3d2aSXin Li
def binary_crossentrop2(y_true, y_pred):
    """Binary cross-entropy weighted by how far each target is from 0.5.

    Each element's cross-entropy is multiplied by ``2*|y_true - 0.5|``, which
    is 0 for a target of 0.5 and 1 for a target of 0 or 1, and the weighted
    values are averaged over the last axis.
    """
    # NOTE(review): the backend call receives (y_pred, y_true) in that order;
    # confirm this matches the K.binary_crossentropy signature of the Keras
    # version this script targets.
    element_loss = K.binary_crossentropy(y_pred, y_true)
    target_weight = 2*K.abs(y_true-0.5)
    return K.mean(target_weight * element_loss, axis=-1)
34*a58d3d2aSXin Li
35*a58d3d2aSXin Li
# The custom loss is passed by name via custom_objects when loading the model.
model = load_model("weights.hdf5", custom_objects={'binary_crossentrop2': binary_crossentrop2})

weights = model.get_weights()

# Use a context manager so the output file is closed even if one of the
# writes below raises.
with open('rnn_weights.c', 'w') as f:
    f.write('/*This file is automatically generated from a Keras model*/\n\n')
    f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "mlp.h"\n\n')

    # The seven weight arrays are emitted in the order get_weights() returns
    # them. NOTE(review): the index-to-name mapping assumes a dense layer, a
    # recurrent layer and a dense layer, in that order -- confirm against the
    # model that produced weights.hdf5.
    printVector(f, weights[0], 'layer0_weights')
    printVector(f, weights[1], 'layer0_bias')
    printVector(f, weights[2], 'layer1_weights')
    printVector(f, weights[3], 'layer1_recur_weights')
    printVector(f, weights[4], 'layer1_bias')
    printVector(f, weights[5], 'layer2_weights')
    printVector(f, weights[6], 'layer2_bias')

    # Layer descriptors referencing the arrays above. The numeric fields are
    # hard-coded here rather than read from the model.
    # NOTE(review): presumably sizes and an activation selector as declared in
    # mlp.h -- verify they match the loaded model.
    f.write('const DenseLayer layer0 = {\n   layer0_bias,\n   layer0_weights,\n   25, 16, 0\n};\n\n')
    f.write('const GRULayer layer1 = {\n   layer1_bias,\n   layer1_weights,\n   layer1_recur_weights,\n   16, 12\n};\n\n')
    f.write('const DenseLayer layer2 = {\n   layer2_bias,\n   layer2_weights,\n   12, 2, 1\n};\n\n')
58