import torch
from torch import nn


class LossGen(nn.Module):
    """Two-layer GRU model that maps a per-step (loss, perc) input pair to a
    single scalar output per step, carrying the GRU hidden states between
    calls so it can be run in a streaming fashion."""

    def __init__(self, gru1_size=16, gru2_size=16):
        super().__init__()

        self.gru1_size = gru1_size
        self.gru2_size = gru2_size

        # Embed the 2-dimensional input (loss, perc) into 8 features, run two
        # stacked GRUs, then project down to a single output per time step.
        self.dense_in = nn.Linear(2, 8)
        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        device = loss.device
        batch_size = loss.size(0)

        # Start from zero hidden states unless the caller passes back the
        # states returned by a previous call (streaming continuation).
        if states is None:
            gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
            gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
        else:
            gru1_state, gru2_state = states

        # Concatenate the two scalar inputs along the feature axis and embed.
        x = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)

        # Return the raw per-step outputs plus the updated GRU states.
        return self.dense_out(gru2_out), [gru1_state, gru2_state]
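
# A minimal smoke test (illustrative only; the shapes and values below are
# assumptions, not part of the original module). It pushes a dummy batch
# through the model with fresh zero states, then continues from the returned
# states the way a streaming caller would. Inputs are (batch, seq, 1) so that
# concatenation along the last axis yields the 2 features dense_in expects.
if __name__ == "__main__":
    model = LossGen()
    batch, seq = 4, 10
    loss = torch.zeros((batch, seq, 1))       # per-step loss indicator
    perc = torch.full((batch, seq, 1), 0.1)   # per-step percentage feature
    out, states = model(loss, perc)           # first call: zero-initialized states
    out2, states = model(loss, perc, states)  # second call: resume from states
    print(out.shape, out2.shape)              # torch.Size([4, 10, 1]) twice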