# xref: /aosp_15_r20/external/libopus/dnn/torch/lossgen/lossgen.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
import torch
from torch import nn
import torch.nn.functional as F


class LossGen(nn.Module):
    """Small recurrent model: a dense input projection followed by two stacked
    GRUs and a dense output layer producing one value per time step."""

    def __init__(self, gru1_size=16, gru2_size=16):
        super(LossGen, self).__init__()

        self.gru1_size = gru1_size
        self.gru2_size = gru2_size
        self.dense_in = nn.Linear(2, 8)
        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        device = loss.device
        batch_size = loss.size(0)
        if states is None:
            # No recurrent state supplied: start both GRUs from zeros.
            gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
        else:
            gru1_state = states[0]
            gru2_state = states[1]
        # Concatenate the two inputs along the feature axis and project to the GRU input size.
        x = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        # Return the per-step output along with the updated recurrent states.
        return self.dense_out(gru2_out), [gru1_state, gru2_state]
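
# --- Usage sketch (illustrative, not part of the original file) --------------
# A minimal example of driving LossGen step by step, assuming inputs shaped
# (batch, sequence, 1): `prev_loss` holds the previous loss decision and `perc`
# a conditioning value. The names, shapes, sigmoid interpretation, and sampling
# loop below are assumptions for illustration, not taken from the upstream
# training or inference scripts.
if __name__ == "__main__":
    model = LossGen()
    batch, steps = 4, 100
    states = None
    prev_loss = torch.zeros((batch, 1, 1))   # assume "no loss" at the first step
    perc = torch.full((batch, 1, 1), 0.1)    # assumed conditioning value
    trace = []
    for _ in range(steps):
        logit, states = model(prev_loss, perc, states)
        p = torch.sigmoid(logit)             # treat the output as a loss probability
        prev_loss = torch.bernoulli(p)       # sample the next loss decision
        trace.append(prev_loss)
    trace = torch.cat(trace, dim=1)          # (batch, steps, 1) sampled trace
    print(trace.mean().item())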