"""
Modification of Tensorflow's Embedding Layer:
    1. Not restricted to be the first layer of a model
    2. Differentiable (allows non-integer lookups)
        - For a non-integer lookup, this layer linearly interpolates between
          the adjacent embeddings in the following way to preserve gradient flow:
        - E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(floor(x) + 1)
          (row indices are clamped to [0, dict_size - 1])
"""

import tensorflow as tf
from tensorflow.keras.layers import Layer


class diff_Embed(Layer):
    """Embedding lookup that linearly interpolates between adjacent rows.

    For each input value x the output is
    ``(1 - frac(x)) * w[floor(x)] + frac(x) * w[floor(x) + 1]``. Both row
    indices are clamped to ``[0, dict_size - 1]``. The output has the input's
    shape with a trailing axis of size ``units`` appended. Gradients reach
    the inputs only through ``frac(x)``, because the integer row selection is
    not differentiable.

    Parameters:
        - units: int
            Dimension of each embedding vector.
        - dict_size: int
            Number of embeddings (rows) in the lookup table.
        - pcm_init: boolean
            If True and ``initializer`` is given, initialize the embedding
            matrix with ``initializer``. Otherwise use
            ``tf.random_normal_initializer()``.
        - initializer: callable or None
            Initializer for the (dict_size, units) embedding matrix. It is
            passed to ``Layer.add_weight``, so it should follow the Keras
            initializer calling convention.
    """

    def __init__(self, units=128, dict_size=256, pcm_init=True, initializer=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dict_size = dict_size
        self.pcm_init = pcm_init
        self.initializer = initializer

    def build(self, input_shape):
        """Create the (dict_size, units) float32 embedding matrix ``self.w``."""
        # Fall back to random-normal when no initializer was supplied. The
        # defaults (pcm_init=True, initializer=None) used to crash here by
        # calling None.
        if self.pcm_init and self.initializer is not None:
            w_init = self.initializer
        else:
            w_init = tf.random_normal_initializer()
        # add_weight registers the variable with the layer, so it shows up in
        # trainable_weights and is updated by optimizers.
        self.w = self.add_weight(
            name="w",
            shape=(self.dict_size, self.units),
            initializer=w_init,
            dtype="float32",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Interpolate between the embeddings at floor(x) and floor(x) + 1.

        Args:
            inputs: numeric tensor of (possibly fractional) row positions,
                of any rank.

        Returns:
            Tensor of shape ``inputs.shape + (units,)``.
        """
        # Work in the weight dtype. tf.math.floor rejects integer tensors, and
        # float64 inputs would not multiply with float32 weights.
        x = tf.cast(inputs, self.w.dtype)
        lower = tf.math.floor(x)
        # Fractional part with a trailing size-1 axis. It broadcasts against
        # the gathered (..., units) rows for inputs of any rank.
        alpha = tf.expand_dims(x - lower, axis=-1)
        max_index = self.dict_size - 1
        idx = tf.cast(lower, tf.int32)
        # Clamp both neighbours into the table so edge/out-of-range positions
        # saturate at the boundary rows instead of failing in tf.gather.
        lo = tf.clip_by_value(idx, 0, max_index)
        hi = tf.clip_by_value(idx + 1, 0, max_index)
        return (1 - alpha) * tf.gather(self.w, lo) + alpha * tf.gather(self.w, hi)

    def get_config(self):
        """Return the layer config, including all constructor arguments."""
        config = super().get_config()
        config.update({"units": self.units})
        config.update({"dict_size": self.dict_size})
        config.update({"pcm_init": self.pcm_init})
        # NOTE(review): the initializer is stored as-is. A non-None initializer
        # object is presumably not JSON-serializable for model saving; consider
        # tf.keras.initializers.serialize/get if configs must round-trip.
        config.update({"initializer": self.initializer})
        return config