from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import activations
from tensorflow.keras import initializers, regularizers, constraints
import numpy as np
import math


class MDense(Layer):
    """Dense layer that sums several tanh-squashed "channels" per output unit.

    For every output unit the layer computes ``channels`` independent affine
    projections of the input, applies ``tanh`` to each, scales each by a
    trainable per-(unit, channel) ``factor`` and sums over the channel axis.
    The optional ``activation`` is applied to that sum.

    Args:
        outputs: Number of output units (stored as ``self.units``).
        channels: Number of tanh channels summed per output unit.
        activation: Activation applied after the channel sum; ``None`` means
            identity.
        use_bias: Whether to add a per-(unit, channel) bias before ``tanh``.
        kernel_initializer: Initializer for the kernel weight.
        bias_initializer: Initializer for the bias weight.
        kernel_regularizer: Regularizer for the kernel weight.
        bias_regularizer: Regularizer for the bias weight; also applied to
            the ``factor`` weight.
        activity_regularizer: Regularizer applied to the layer output.
        kernel_constraint: Constraint for the kernel weight.
        bias_constraint: Constraint for the bias weight; also applied to the
            ``factor`` weight.
        **kwargs: Forwarded to ``Layer``. A legacy ``input_dim`` keyword is
            converted to ``input_shape``.
    """

    def __init__(self, outputs,
                 channels=2,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Keras-1 style `input_dim` is translated to `input_shape`.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(MDense, self).__init__(**kwargs)
        self.units = outputs
        self.channels = channels
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    def build(self, input_shape):
        """Create ``kernel``, optional ``bias`` and ``factor`` weights.

        Raises:
            ValueError: if the input has rank < 2 or an undefined last axis.
        """
        if len(input_shape) < 2:
            raise ValueError(
                'MDense expects inputs of rank >= 2, got shape %s'
                % (input_shape,))
        input_dim = input_shape[-1]
        if input_dim is None:
            raise ValueError(
                'The last dimension of the inputs to MDense must be defined, '
                'got shape %s' % (input_shape,))

        # Layout (units, input_dim, channels): K.dot contracts the input's
        # last axis with the kernel's second-to-last axis.
        self.kernel = self.add_weight(shape=(self.units, input_dim, self.channels),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units, self.channels),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Per-(unit, channel) output scale, initialised to 1. It deliberately
        # shares the bias regularizer/constraint (no dedicated arguments).
        self.factor = self.add_weight(shape=(self.units, self.channels),
                                      initializer='ones',
                                      name='factor',
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        """Compute ``activation(sum_c factor * tanh(inputs . kernel + bias))``.

        ``K.dot`` yields shape ``(..., units, channels)``. The channel axis is
        reduced after the tanh/factor weighting, giving ``(..., units)``.
        """
        output = K.dot(inputs, self.kernel)
        if self.use_bias:
            output = output + self.bias
        output = K.tanh(output) * self.factor
        output = K.sum(output, axis=-1)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last axis replaced by ``units``.

        Raises:
            ValueError: if the input has rank < 2 or an undefined last axis.
        """
        if not input_shape or len(input_shape) < 2:
            raise ValueError(
                'MDense expects inputs of rank >= 2, got shape %s'
                % (input_shape,))
        if not input_shape[-1]:
            raise ValueError(
                'The last dimension of the inputs to MDense must be defined, '
                'got shape %s' % (input_shape,))
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

    def get_config(self):
        """Return a config dict that ``from_config`` can round-trip.

        Keys match the ``__init__`` argument names. ``outputs`` is used rather
        than ``units``, and ``channels`` is included, so that
        ``MDense(**config)`` rebuilds an equivalent layer.
        """
        config = {
            'outputs': self.units,
            'channels': self.channels,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(MDense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        """Rebuild a layer from ``get_config`` output.

        Also accepts configs written by older versions of this layer, which
        stored the unit count under ``'units'`` instead of ``'outputs'``.
        Such configs also lacked ``'channels'``, so the default is used.
        """
        config = dict(config)
        if 'outputs' not in config and 'units' in config:
            config['outputs'] = config.pop('units')
        return cls(**config)