# xref: /aosp_15_r20/external/libopus/dnn/training_tf2/mdense.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
import math

import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import activations
from tensorflow.keras import initializers, regularizers, constraints


class MDense(Layer):
    """Dense layer that mixes several tanh "channels" per output unit.

    For each output unit ``u`` the layer forms ``channels`` affine projections
    of the input. It applies ``tanh`` to each, scales each by a learned
    ``factor[u, c]``, and sums over the channel axis. ``activation`` is then
    applied to that sum::

        y[..., u] = activation(sum_c factor[u, c]
                               * tanh(x . kernel[u, :, c] + bias[u, c]))

    Args:
        outputs: Number of output units (stored as ``self.units``).
        channels: Number of tanh channels summed per output unit.
        activation: Activation applied after the channel sum, resolved with
            ``activations.get``.
        use_bias: Whether to add a learned ``(outputs, channels)`` bias before
            the tanh.
        kernel_initializer: Initializer for the ``kernel`` weight.
        bias_initializer: Initializer for the ``bias`` weight.
        kernel_regularizer: Regularizer for ``kernel``.
        bias_regularizer: Regularizer for ``bias``; also applied to ``factor``.
        activity_regularizer: Regularizer for the layer output.
        kernel_constraint: Constraint for ``kernel``.
        bias_constraint: Constraint for ``bias``; also applied to ``factor``.
        **kwargs: Forwarded to ``Layer``. A legacy ``input_dim`` is converted
            to ``input_shape`` when ``input_shape`` is not given.
    """

    def __init__(self, outputs,
                 channels=2,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Accept the old-style ``input_dim`` keyword, as keras' Dense does.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(MDense, self).__init__(**kwargs)
        self.units = outputs
        self.channels = channels
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    def build(self, input_shape):
        """Create the ``kernel``, optional ``bias`` and ``factor`` weights.

        ``kernel`` has shape ``(units, input_dim, channels)``. ``bias`` and
        ``factor`` have shape ``(units, channels)``. ``input_dim`` is the last
        entry of ``input_shape``.
        """
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        self.kernel = self.add_weight(shape=(self.units, input_dim, self.channels),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units, self.channels),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Per-(unit, channel) output scale. It is initialised to 1 and reuses
        # the bias regularizer/constraint; there are no dedicated arguments.
        self.factor = self.add_weight(shape=(self.units, self.channels),
                                      initializer='ones',
                                      name='factor',
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
        # The kernel now depends on input_dim, so pin the last input axis.
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        """Apply the layer to ``inputs`` of shape ``(..., input_dim)``.

        Returns:
            A tensor of shape ``(..., units)``.
        """
        # K.dot contracts the last axis of ``inputs`` with the kernel's
        # input_dim axis, giving shape (..., units, channels). That is the
        # shape the (units, channels) bias and factor broadcast against.
        output = K.dot(inputs, self.kernel)
        if self.use_bias:
            output = output + self.bias
        output = K.tanh(output) * self.factor
        # Collapse the channel axis -> (..., units).
        output = K.sum(output, axis=-1)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last dimension replaced by ``units``."""
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

    def get_config(self):
        """Return a serialisable config that ``from_config`` can rebuild.

        The keys mirror the ``__init__`` parameters (``outputs``, ``channels``,
        ...), so ``MDense.from_config(layer.get_config())`` recreates an
        equivalent layer. ``'units'`` is also kept for code that read it from
        earlier versions of this config.
        """
        config = {
            'outputs': self.units,
            'channels': self.channels,
            # Legacy key; ignored by from_config when 'outputs' is present.
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(MDense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        """Create an ``MDense`` from the output of ``get_config``.

        Also accepts configs written by earlier versions of this class. Those
        stored the unit count under ``'units'`` (not an ``__init__`` argument)
        and had no ``'channels'`` key. For such configs the ``__init__``
        default for ``channels`` is used, which may differ from the original
        layer's value.
        """
        # Copy so the caller's dict is not mutated.
        config = dict(config)
        # 'units' must not reach Layer.__init__, which rejects unknown kwargs.
        legacy_units = config.pop('units', None)
        if 'outputs' not in config and legacy_units is not None:
            config['outputs'] = legacy_units
        return cls(**config)
