1#!/usr/bin/python 2 3from __future__ import print_function 4 5from keras.models import Sequential 6from keras.layers import Dense 7from keras.layers import LSTM 8from keras.layers import GRU 9from keras.models import load_model 10from keras import backend as K 11import sys 12import re 13import numpy as np 14 15def printVector(f, ft, vector, name): 16 v = np.reshape(vector, (-1)); 17 #print('static const float ', name, '[', len(v), '] = \n', file=f) 18 f.write('static const rnn_weight {}[{}] = {{\n '.format(name, len(v))) 19 for i in range(0, len(v)): 20 f.write('{}'.format(min(127, int(round(256*v[i]))))) 21 ft.write('{}'.format(min(127, int(round(256*v[i]))))) 22 if (i!=len(v)-1): 23 f.write(',') 24 else: 25 break; 26 ft.write(" ") 27 if (i%8==7): 28 f.write("\n ") 29 else: 30 f.write(" ") 31 #print(v, file=f) 32 f.write('\n};\n\n') 33 ft.write("\n") 34 return; 35 36def printLayer(f, ft, layer): 37 weights = layer.get_weights() 38 activation = re.search('function (.*) at', str(layer.activation)).group(1).upper() 39 if len(weights) > 2: 40 ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]/3)) 41 else: 42 ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1])) 43 if activation == 'SIGMOID': 44 ft.write('1\n') 45 elif activation == 'RELU': 46 ft.write('2\n') 47 else: 48 ft.write('0\n') 49 printVector(f, ft, weights[0], layer.name + '_weights') 50 if len(weights) > 2: 51 printVector(f, ft, weights[1], layer.name + '_recurrent_weights') 52 printVector(f, ft, weights[-1], layer.name + '_bias') 53 name = layer.name 54 if len(weights) > 2: 55 f.write('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' 56 .format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation)) 57 else: 58 f.write('static const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' 59 .format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation)) 60 61def structLayer(f, layer): 62 weights = layer.get_weights() 63 name = layer.name 64 if len(weights) > 2: 65 f.write(' {},\n'.format(weights[0].shape[1]/3)) 66 else: 67 f.write(' {},\n'.format(weights[0].shape[1])) 68 f.write(' &{},\n'.format(name)) 69 70 71def foo(c, name): 72 return None 73 74def mean_squared_sqrt_error(y_true, y_pred): 75 return K.mean(K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1) 76 77 78model = load_model(sys.argv[1], custom_objects={'msse': mean_squared_sqrt_error, 'mean_squared_sqrt_error': mean_squared_sqrt_error, 'my_crossentropy': mean_squared_sqrt_error, 'mycost': mean_squared_sqrt_error, 'WeightClip': foo}) 79 80weights = model.get_weights() 81 82f = open(sys.argv[2], 'w') 83ft = open(sys.argv[3], 'w') 84 85f.write('/*This file is automatically generated from a Keras model*/\n\n') 86f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n#include "rnn_data.h"\n\n') 87ft.write('rnnoise-nu model file version 1\n') 88 89layer_list = [] 90for i, layer in enumerate(model.layers): 91 if len(layer.get_weights()) > 0: 92 printLayer(f, ft, layer) 93 if len(layer.get_weights()) > 2: 94 layer_list.append(layer.name) 95 96f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(sys.argv[4])) 97for i, layer in enumerate(model.layers): 98 if len(layer.get_weights()) > 0: 99 structLayer(f, layer) 100f.write('};\n') 101 102#hf.write('struct RNNState {\n') 103#for i, name in enumerate(layer_list): 104# hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper())) 105#hf.write('};\n') 106 107f.close() 108