"""
Modification of Tensorflow's Embedding Layer:
    1. Not restricted to be the first layer of a model
    2. Differentiable (allows non-integer lookups)
        - For a non-integer lookup, this layer linearly interpolates between the adjacent embeddings in the following way to preserve gradient flow
            - E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x))
            - e.g. x = 3.25 gives E = 0.75*embed(3) + 0.25*embed(4)
"""

import tensorflow as tf
from tensorflow.keras.layers import Layer

class diff_Embed(Layer):
    """
    Parameters:
        - units: int
            Dimension of the embedding
        - dict_size: int
            Number of embeddings in the lookup table
        - pcm_init: boolean
            Whether to initialize the embedding matrix with the user-supplied
            initializer (otherwise a random-normal initializer is used)
        - initializer: keras initializer
            Initializer for the embedding matrix, used when pcm_init is True
    """
    def __init__(self, units=128, dict_size=256, pcm_init=True, initializer=None, **kwargs):
        super(diff_Embed, self).__init__(**kwargs)
        self.units = units
        self.dict_size = dict_size
        self.pcm_init = pcm_init
        self.initializer = initializer

    def build(self, input_shape):
        # Embedding matrix of shape (dict_size, units); use the user-supplied
        # initializer when pcm_init is set, otherwise a random-normal one.
        w_init = tf.random_normal_initializer()
        if self.pcm_init:
            w_init = self.initializer
        self.w = tf.Variable(initial_value=w_init(shape=(self.dict_size, self.units), dtype='float32'), trainable=True)

    def call(self, inputs):
        # Fractional part of the lookup index is the interpolation weight.
        alpha = inputs - tf.math.floor(inputs)
        alpha = tf.expand_dims(alpha, axis=-1)
        alpha = tf.tile(alpha, [1, 1, 1, self.units])
        # Integer (floor) part of the lookup index.
        inputs = tf.cast(inputs, 'int32')
        # Linearly interpolate between the embeddings of floor(x) and ceil(x);
        # the upper index is clipped so it stays inside the embedding table.
        M = (1 - alpha) * tf.gather(self.w, inputs) + alpha * tf.gather(self.w, tf.clip_by_value(inputs + 1, 0, self.dict_size - 1))
        return M

    def get_config(self):
        config = super(diff_Embed, self).get_config()
        config.update({"units": self.units})
        config.update({"dict_size": self.dict_size})
        config.update({"pcm_init": self.pcm_init})
        config.update({"initializer": self.initializer})
        return config
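

# A minimal usage sketch, not part of the original file: it assumes a 3-D input
# of shape (batch, time, features) holding fractional indices in
# [0, dict_size - 1] (e.g. u-law/PCM values rescaled to that range) and uses a
# plain Gaussian initializer standing in for the PCM-specific one used in the
# actual training scripts.
if __name__ == "__main__":
    import numpy as np

    layer = diff_Embed(units=128, dict_size=256, pcm_init=True,
                       initializer=tf.random_normal_initializer(stddev=0.1))
    # Non-integer lookups: 3.25 is embedded as 0.75*embed(3) + 0.25*embed(4).
    x = tf.constant(np.random.uniform(0.0, 255.0, size=(2, 10, 1)), dtype='float32')
    y = layer(x)
    print(y.shape)  # (2, 10, 1, 128): one 128-dim embedding per input index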