from typing import Tuple

import torch
from torch import Tensor


def milstm_cell(
    x: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    alpha: Tensor,
    beta_i: Tensor,
    beta_h: Tensor,
    bias: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one step of a multiplicative-integration LSTM cell.

    The gate pre-activations are
    ``alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias`` (elementwise),
    where ``Wx = x @ w_ih.T`` and ``Uz = hx @ w_hh.T``. After that the update
    is the ordinary LSTM recurrence.

    Args:
        x: Input for this step; must be 2-D because ``Tensor.mm`` is used.
        hx: Previous hidden state (2-D).
        cx: Previous cell state.
        w_ih: Input-to-hidden weights. Rows are stacked as input, forget,
            cell and output gates.
        w_hh: Hidden-to-hidden weights, with the same gate stacking.
        alpha: Coefficient of the multiplicative ``Wx * Uz`` term.
        beta_i: Coefficient of the ``Wx`` term.
        beta_h: Coefficient of the ``Uz`` term.
        bias: Additive bias.
            The four coefficient tensors broadcast against the gate
            pre-activations (presumably 1-D of length 4 * hidden_size;
            confirm against callers).

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())

    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias

    # Same as LSTMCell after this point
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()

    return hy, cy


def lstm_cell(
    input: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step from explicit weight and bias tensors.

    Args:
        input: Input for this step; must be 2-D because ``torch.mm`` is used.
        hidden: ``(hx, cx)``, the previous hidden state and cell state.
        w_ih: Input-to-hidden weights. Rows are stacked as input, forget,
            cell and output gates, matching the ``chunk(4, 1)`` split below.
        w_hh: Hidden-to-hidden weights, with the same gate stacking.
        b_ih: Input-to-hidden bias.
        b_hh: Hidden-to-hidden bias.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    hx, cx = hidden
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    # Split the stacked pre-activations along dim 1 into the four gates.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def flat_lstm_cell(
    input: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step, taking the hidden and cell state as separate arguments.

    This computes the same thing as ``lstm_cell``. The only difference is that
    ``hx`` and ``cx`` are passed as separate tensors instead of as a tuple.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def premul_lstm_cell(
    igates: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step whose input projection was computed ahead of time.

    Args:
        igates: Precomputed input-to-hidden product, presumably
            ``input @ w_ih.T``. ``b_ih`` is still added here, so ``igates``
            should not already include it.
        hidden: ``(hx, cx)``, the previous hidden state and cell state.
        w_hh: Hidden-to-hidden weights, with gates stacked as input, forget,
            cell and output.
        b_ih: Input-to-hidden bias.
        b_hh: Hidden-to-hidden bias.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    hx, cx = hidden
    gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def premul_lstm_cell_no_bias(
    igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step from precomputed input gates, adding only ``b_hh``.

    Unlike ``premul_lstm_cell``, no ``b_ih`` is added here. Presumably
    ``igates`` already includes the input-to-hidden bias; confirm against
    callers.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    hx, cx = hidden
    gates = igates + torch.mm(hx, w_hh.t()) + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def gru_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Compute one GRU step and return the new hidden state.

    Both weight matrices stack their rows as reset, update ("input") and new
    gates, matching the ``chunk(3, 1)`` splits below. The input projection and
    the hidden projection are kept separate. This is required because the
    reset gate scales only the hidden part of the new gate.
    """
    gi = torch.mm(input, w_ih.t()) + b_ih
    gh = torch.mm(hidden, w_hh.t()) + b_hh
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    # Equivalent to (1 - inputgate) * newgate + inputgate * hidden.
    hy = newgate + inputgate * (hidden - newgate)

    return hy


def rnn_relu_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Compute one step of a plain (Elman) RNN cell with a ReLU nonlinearity.

    Returns ``relu(input @ w_ih.T + b_ih + hidden @ w_hh.T + b_hh)``.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.relu(igates + hgates)


def rnn_tanh_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Compute one step of a plain (Elman) RNN cell with a tanh nonlinearity.

    Returns ``tanh(input @ w_ih.T + b_ih + hidden @ w_hh.T + b_hh)``.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.tanh(igates + hgates)