"""
Modification of Tensorflow's Embedding Layer:
    1. Not restricted to be the first layer of a model
    2. Differentiable (allows non-integer lookups)
        - For a non-integer lookup, this layer linearly interpolates between the
          adjacent embeddings in the following way to preserve gradient flow
        - E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x))
"""

import tensorflow as tf
from tensorflow.keras.layers import Layer


class diff_Embed(Layer):
    """Differentiable embedding lookup with linear interpolation.

    Parameters:
        - units: int
            Dimension of each embedding vector
        - dict_size: int
            Number of embeddings in the lookup table
        - pcm_init: boolean
            If True, use `initializer` for the embedding matrix; otherwise a
            random-normal initializer is used
        - initializer: callable or None
            Keras-style initializer (called as `initializer(shape, dtype)`)
            used when `pcm_init` is True
    """

    def __init__(self, units=128, dict_size=256, pcm_init=True, initializer=None, **kwargs):
        super(diff_Embed, self).__init__(**kwargs)
        self.units = units
        self.dict_size = dict_size
        self.pcm_init = pcm_init
        self.initializer = initializer

    def build(self, input_shape):
        # Default to random normal unless a custom (e.g. PCM-derived)
        # initializer was requested.
        w_init = tf.random_normal_initializer()
        if self.pcm_init:
            w_init = self.initializer
        # Embedding table: one `units`-dimensional vector per dictionary entry.
        self.w = tf.Variable(
            initial_value=w_init(shape=(self.dict_size, self.units), dtype='float32'),
            trainable=True,
        )

    def call(self, inputs):
        """Interpolated lookup: E = (1-frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x)).

        Accepts inputs of any rank; the embedding dimension is appended as the
        last axis of the output.
        """
        # Fractional part of each lookup value drives the interpolation weight.
        alpha = inputs - tf.math.floor(inputs)
        # Shape (..., 1) so it broadcasts against the gathered (..., units)
        # embeddings. (The original tiled alpha to a fixed rank-4 shape, which
        # hard-coded rank-3 inputs; broadcasting is equivalent and rank-agnostic.)
        alpha = tf.expand_dims(alpha, axis=-1)
        # Integer floor index, clipped into the valid table range. Using
        # tf.math.floor before the cast keeps the index consistent with the
        # floor used for alpha (a bare int32 cast truncates toward zero, which
        # disagrees with floor for negative inputs).
        low = tf.clip_by_value(tf.cast(tf.math.floor(inputs), 'int32'),
                               0, self.dict_size - 1)
        # Ceil index. Bug fix: clip to dict_size - 1 instead of the hard-coded
        # 255, which was only correct for the default dict_size of 256.
        high = tf.clip_by_value(low + 1, 0, self.dict_size - 1)
        return (1 - alpha) * tf.gather(self.w, low) + alpha * tf.gather(self.w, high)

    def get_config(self):
        config = super(diff_Embed, self).get_config()
        config.update({"units": self.units})
        config.update({"dict_size": self.dict_size})
        config.update({"pcm_init": self.pcm_init})
        # NOTE(review): a callable initializer is not JSON-serializable as-is;
        # tf.keras.initializers.serialize/deserialize may be needed for
        # round-tripping this config — confirm against how the layer is saved.
        config.update({"initializer": self.initializer})
        return config