xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/cells.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom typing import Tuple
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
    """Compute one step of a multiplicative-integration LSTM cell.

    The input and recurrent projections are combined with the formula from
    Section 2.1 of https://arxiv.org/pdf/1606.06630.pdf rather than simply
    being added; the gate handling afterwards is the same as in a plain
    LSTM cell.

    Args:
        x: Step input, projected as ``x @ w_ih.T``.
        hx: Hidden state, projected as ``hx @ w_hh.T``.
        cx: Cell state that is scaled by the forget gate.
        w_ih: Weight matrix for the input projection.
        w_hh: Weight matrix for the recurrent projection.
        alpha: Factor on the element-wise product of both projections.
        beta_i: Factor on the input projection.
        beta_h: Factor on the recurrent projection.
        bias: Term added to the gate pre-activations.

    Returns:
        The tuple ``(hy, cy)`` with the updated hidden and cell states.
    """
    input_proj = torch.mm(x, w_ih.t())
    recurrent_proj = torch.mm(hx, w_hh.t())

    # Multiplicative integration: second-order term plus both first-order
    # terms and the bias, summed left to right.
    pre_act = (
        alpha * input_proj * recurrent_proj
        + beta_i * input_proj
        + beta_h * recurrent_proj
        + bias
    )

    # Dim 1 holds the four gate blocks in the order: input, forget, cell, output.
    i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, 1)

    in_gate = torch.sigmoid(i_pre)
    forget_gate = torch.sigmoid(f_pre)
    candidate = torch.tanh(g_pre)
    out_gate = torch.sigmoid(o_pre)

    cy = (forget_gate * cx) + (in_gate * candidate)
    hy = out_gate * torch.tanh(cy)

    return hy, cy
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
def lstm_cell(
    input: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step from an input and a ``(hx, cx)`` state pair.

    Args:
        input: Step input, projected as ``input @ w_ih.T``.
        hidden: Pair ``(hx, cx)``; ``hx`` feeds the recurrent projection and
            ``cx`` is the cell state scaled by the forget gate.
        w_ih: Weight matrix for the input projection.
        w_hh: Weight matrix for the recurrent projection.
        b_ih: First bias added to the gate pre-activations.
        b_hh: Second bias added to the gate pre-activations.

    Returns:
        The tuple ``(hy, cy)`` with the updated hidden and cell states.
    """
    h_prev, c_prev = hidden

    input_proj = input.mm(w_ih.t())
    recurrent_proj = h_prev.mm(w_hh.t())
    pre_act = input_proj + recurrent_proj + b_ih + b_hh

    # Dim 1 holds the four gate blocks in the order: input, forget, cell, output.
    i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, 1)

    in_gate = i_pre.sigmoid()
    forget_gate = f_pre.sigmoid()
    candidate = g_pre.tanh()
    out_gate = o_pre.sigmoid()

    cy = (forget_gate * c_prev) + (in_gate * candidate)
    hy = out_gate * cy.tanh()

    return hy, cy
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
def flat_lstm_cell(
    input: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step with the state passed as two separate tensors.

    Performs the same arithmetic as an LSTM cell that takes a ``(hx, cx)``
    tuple, but receives ``hx`` and ``cx`` as individual arguments.

    Args:
        input: Step input, projected as ``input @ w_ih.T``.
        hx: Hidden state, projected as ``hx @ w_hh.T``.
        cx: Cell state that is scaled by the forget gate.
        w_ih: Weight matrix for the input projection.
        w_hh: Weight matrix for the recurrent projection.
        b_ih: First bias added to the gate pre-activations.
        b_hh: Second bias added to the gate pre-activations.

    Returns:
        The tuple ``(hy, cy)`` with the updated hidden and cell states.
    """
    input_proj = input.mm(w_ih.t())
    recurrent_proj = hx.mm(w_hh.t())
    pre_act = input_proj + recurrent_proj + b_ih + b_hh

    # Dim 1 holds the four gate blocks in the order: input, forget, cell, output.
    i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, 1)

    in_gate = i_pre.sigmoid()
    forget_gate = f_pre.sigmoid()
    candidate = g_pre.tanh()
    out_gate = o_pre.sigmoid()

    cy = (forget_gate * cx) + (in_gate * candidate)
    hy = out_gate * cy.tanh()

    return hy, cy
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker
def premul_lstm_cell(
    igates: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step from an already-computed input contribution.

    Unlike a cell that multiplies the input by an input weight matrix, this
    variant receives that contribution as ``igates`` and only performs the
    recurrent matrix multiply itself.

    Args:
        igates: Input contribution to the gate pre-activations.
        hidden: Pair ``(hx, cx)``; ``hx`` feeds the recurrent projection and
            ``cx`` is the cell state scaled by the forget gate.
        w_hh: Weight matrix for the recurrent projection.
        b_ih: First bias added to the gate pre-activations.
        b_hh: Second bias added to the gate pre-activations.

    Returns:
        The tuple ``(hy, cy)`` with the updated hidden and cell states.
    """
    h_prev, c_prev = hidden

    recurrent_proj = h_prev.mm(w_hh.t())
    pre_act = igates + recurrent_proj + b_ih + b_hh

    # Dim 1 holds the four gate blocks in the order: input, forget, cell, output.
    i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, 1)

    in_gate = i_pre.sigmoid()
    forget_gate = f_pre.sigmoid()
    candidate = g_pre.tanh()
    out_gate = o_pre.sigmoid()

    cy = (forget_gate * c_prev) + (in_gate * candidate)
    hy = out_gate * cy.tanh()

    return hy, cy
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
def premul_lstm_cell_no_bias(
    igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
) -> Tuple[Tensor, Tensor]:
    """Compute one LSTM step from a precomputed input term and a single bias.

    Only ``b_hh`` is added here; no separate input bias is applied inside
    this function.

    Args:
        igates: Input contribution to the gate pre-activations.
        hidden: Pair ``(hx, cx)``; ``hx`` feeds the recurrent projection and
            ``cx`` is the cell state scaled by the forget gate.
        w_hh: Weight matrix for the recurrent projection.
        b_hh: Bias added to the gate pre-activations.

    Returns:
        The tuple ``(hy, cy)`` with the updated hidden and cell states.
    """
    h_prev, c_prev = hidden

    recurrent_proj = h_prev.mm(w_hh.t())
    pre_act = igates + recurrent_proj + b_hh

    # Dim 1 holds the four gate blocks in the order: input, forget, cell, output.
    i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, 1)

    in_gate = i_pre.sigmoid()
    forget_gate = f_pre.sigmoid()
    candidate = g_pre.tanh()
    out_gate = o_pre.sigmoid()

    cy = (forget_gate * c_prev) + (in_gate * candidate)
    hy = out_gate * cy.tanh()

    return hy, cy
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker
def gru_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    """Compute one GRU step and return the updated hidden state.

    Args:
        input: Step input, projected as ``input @ w_ih.T + b_ih``.
        hidden: Hidden state, projected as ``hidden @ w_hh.T + b_hh``.
        w_ih: Weight matrix for the input projection.
        w_hh: Weight matrix for the recurrent projection.
        b_ih: Bias added to the input projection.
        b_hh: Bias added to the recurrent projection.

    Returns:
        The updated hidden state ``hy``.
    """
    input_proj = input.mm(w_ih.t()) + b_ih
    recurrent_proj = hidden.mm(w_hh.t()) + b_hh

    # Each projection is split along dim 1 into three blocks, in the order
    # reset, input, new.
    in_reset, in_update, in_new = input_proj.chunk(3, 1)
    rec_reset, rec_update, rec_new = recurrent_proj.chunk(3, 1)

    reset_gate = (in_reset + rec_reset).sigmoid()
    input_gate = (in_update + rec_update).sigmoid()
    # The reset gate scales only the recurrent part of the candidate.
    candidate = (in_new + reset_gate * rec_new).tanh()

    # Blend of the candidate and the incoming state, weighted by input_gate.
    hy = candidate + input_gate * (hidden - candidate)

    return hy
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker
def rnn_relu_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    """Compute one step of a simple RNN cell with a ReLU activation.

    Args:
        input: Step input, projected as ``input @ w_ih.T + b_ih``.
        hidden: Hidden state, projected as ``hidden @ w_hh.T + b_hh``.
        w_ih: Weight matrix for the input projection.
        w_hh: Weight matrix for the recurrent projection.
        b_ih: Bias added to the input projection.
        b_hh: Bias added to the recurrent projection.

    Returns:
        ``relu`` applied to the sum of both projections.
    """
    input_proj = input.mm(w_ih.t()) + b_ih
    recurrent_proj = hidden.mm(w_hh.t()) + b_hh
    pre_act = input_proj + recurrent_proj
    return pre_act.relu()
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker
def rnn_tanh_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    """Compute one step of a simple RNN cell with a tanh activation.

    Args:
        input: Step input, projected as ``input @ w_ih.T + b_ih``.
        hidden: Hidden state, projected as ``hidden @ w_hh.T + b_hh``.
        w_ih: Weight matrix for the input projection.
        w_hh: Weight matrix for the recurrent projection.
        b_ih: Bias added to the input projection.
        b_hh: Bias added to the recurrent projection.

    Returns:
        ``tanh`` applied to the sum of both projections.
    """
    input_proj = input.mm(w_ih.t()) + b_ih
    recurrent_proj = hidden.mm(w_hh.t()) + b_hh
    pre_act = input_proj + recurrent_proj
    return pre_act.tanh()
142