1"""
2Modification of Tensorflow's Embedding Layer:
3    1. Not restricted to be the first layer of a model
4    2. Differentiable (allows non-integer lookups)
5        - For non integer lookup, this layer linearly interpolates between the adjacent embeddings in the following way to preserver gradient flow
6            - E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x))
7"""

import tensorflow as tf
from tensorflow.keras.layers import Layer

class diff_Embed(Layer):
    """
    Parameters:
        - units: int
            Dimension of the embedding
        - dict_size: int
            Number of embeddings in the lookup table
        - pcm_init: boolean
            If True, use the provided initializer for the embedding matrix;
            otherwise use a random-normal initializer
        - initializer:
            Initializer for the embedding matrix (used when pcm_init is True)
    """
    def __init__(self, units=128, dict_size=256, pcm_init=True, initializer=None, **kwargs):
        super(diff_Embed, self).__init__(**kwargs)
        self.units = units
        self.dict_size = dict_size
        self.pcm_init = pcm_init
        self.initializer = initializer

    def build(self, input_shape):
        # Use the supplied initializer (typically derived from a PCM/mu-law
        # table) when pcm_init is set; otherwise fall back to random normals.
        w_init = tf.random_normal_initializer()
        if self.pcm_init:
            w_init = self.initializer
        self.w = tf.Variable(initial_value=w_init(shape=(self.dict_size, self.units), dtype='float32'),
                             trainable=True)

    def call(self, inputs):
        # Fractional part of each index: the interpolation weight for the
        # next-higher embedding.
        alpha = inputs - tf.math.floor(inputs)
        alpha = tf.expand_dims(alpha, axis=-1)
        alpha = tf.tile(alpha, [1, 1, 1, self.units])
        # Truncate to the lower integer index (inputs are assumed non-negative).
        inputs = tf.cast(inputs, 'int32')
        # E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x)),
        # clamping ceil(x) to the last valid table entry.
        M = (1 - alpha) * tf.gather(self.w, inputs) \
            + alpha * tf.gather(self.w, tf.clip_by_value(inputs + 1, 0, self.dict_size - 1))
        return M

    def get_config(self):
        config = super(diff_Embed, self).get_config()
        config.update({"units": self.units,
                       "dict_size": self.dict_size,
                       "pcm_init": self.pcm_init,
                       "initializer": self.initializer})
        return config
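

# A minimal usage sketch (not part of the original file): it assumes call()
# receives a rank-3 tensor of shape (batch, time, features), since alpha is
# tiled across four dimensions, and checks that gradients flow back through
# the non-integer lookup.
if __name__ == "__main__":
    import numpy as np

    layer = diff_Embed(units=4, dict_size=256, pcm_init=False)
    x = tf.constant(np.array([[[0.25, 100.5]]], dtype='float32'))  # (1, 1, 2)
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = layer(x)              # (1, 1, 2, 4)
        loss = tf.reduce_sum(y)
    print(y.shape)
    # A non-None gradient confirms the interpolated lookup is differentiable
    # with respect to the (non-integer) input indices.
    print(tape.gradient(loss, x))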