xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/word_language_model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# The model is from here:
2#   https://github.com/pytorch/examples/blob/master/word_language_model/model.py
3
4from typing import Optional, Tuple
5
6import torch
7import torch.nn as nn
8from torch import Tensor
9
10
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    Mirrors the word-language-model example from pytorch/examples: token ids
    are embedded, passed through the selected recurrent core, and projected
    back to vocabulary logits.
    """

    def __init__(
        self,
        rnn_type,
        ntoken,
        ninp,
        nhid,
        nlayers,
        dropout=0.5,
        tie_weights=False,
        batchsize=2,
    ):
        """
        Args:
            rnn_type: One of "LSTM", "GRU", "RNN_TANH", "RNN_RELU".
            ntoken: Vocabulary size (embedding rows / output classes).
            ninp: Embedding (input) dimension.
            nhid: Hidden dimension of the recurrent core.
            nlayers: Number of stacked recurrent layers.
            dropout: Dropout probability applied to the embeddings, between
                RNN layers, and to the RNN output.
            tie_weights: Share the decoder weight with the encoder embedding
                (Press & Wolf 2016; Inan et al. 2016). Requires nhid == ninp.
            batchsize: Batch size used to pre-allocate ``self.hidden``.

        Raises:
            ValueError: If ``rnn_type`` is unrecognized, or ``tie_weights``
                is set while ``nhid != ninp``.
        """
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ["LSTM", "GRU"]:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    """An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"""
                ) from None
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
        # Pre-allocated hidden state so forward() can be driven statefully.
        self.hidden = self.init_hidden(batchsize)

    @staticmethod
    def repackage_hidden(h):
        """Detach hidden states from their history.

        Recurses into tuples so both Tensor (RNN/GRU) and (h, c) tuple
        (LSTM) hidden states are handled.
        """
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple([RNNModel.repackage_hidden(v) for v in h])

    def init_weights(self):
        """Initialize encoder/decoder weights uniformly, zero the decoder bias.

        Uses ``nn.init`` under ``torch.no_grad()`` instead of mutating
        ``.data`` directly, which bypasses autograd bookkeeping and is the
        deprecated pattern.
        """
        initrange = 0.1
        with torch.no_grad():
            nn.init.uniform_(self.encoder.weight, -initrange, initrange)
            nn.init.zeros_(self.decoder.bias)
            nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
        """Embed ``input`` (seq, batch) and return logits (seq, batch, ntoken).

        Side effect: caches the detached new hidden state in ``self.hidden``.
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # Flatten (seq, batch, nhid) -> (seq * batch, nhid) for the linear layer.
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        self.hidden = RNNModel.repackage_hidden(hidden)
        return decoded.view(output.size(0), output.size(1), decoded.size(1))

    def init_hidden(self, bsz):
        """Return a zeroed hidden state for batch size ``bsz``.

        LSTM needs an (h, c) pair; RNN/GRU need a single tensor. The state
        is created on the same device/dtype as the model parameters.
        """
        weight = next(self.parameters())
        if self.rnn_type == "LSTM":
            return (
                weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid),
            )
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)
97
class RNNModelWithTensorHidden(RNNModel):
    """Supports GRU scripting: the hidden state is a single Tensor."""

    @staticmethod
    def repackage_hidden(h):
        """Return the hidden state detached from its autograd history."""
        return h.detach()

    def forward(self, input: Tensor, hidden: Tensor):
        """Embed ``input``, run the RNN, and return vocabulary logits.

        Side effect: caches the detached new hidden state in ``self.hidden``.
        """
        embedded = self.drop(self.encoder(input))
        rnn_out, hidden = self.rnn(embedded, hidden)
        rnn_out = self.drop(rnn_out)
        seq_len = rnn_out.size(0)
        batch = rnn_out.size(1)
        # Flatten time and batch dims so the decoder sees a 2-D matrix.
        flat = rnn_out.view(seq_len * batch, rnn_out.size(2))
        decoded = self.decoder(flat)
        self.hidden = RNNModelWithTensorHidden.repackage_hidden(hidden)
        return decoded.view(seq_len, batch, decoded.size(1))
116
class RNNModelWithTupleHidden(RNNModel):
    """Supports LSTM scripting: the hidden state is an (h, c) tuple."""

    @staticmethod
    def repackage_hidden(h: Tuple[Tensor, Tensor]):
        """Return both hidden-state tensors detached from autograd history."""
        h_n, c_n = h[0], h[1]
        return (h_n.detach(), c_n.detach())

    def forward(self, input: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        """Embed ``input``, run the LSTM, and return vocabulary logits.

        Side effect: caches the detached (h, c) pair in ``self.hidden``.
        """
        embedded = self.drop(self.encoder(input))
        rnn_out, hidden = self.rnn(embedded, hidden)
        rnn_out = self.drop(rnn_out)
        seq_len = rnn_out.size(0)
        batch = rnn_out.size(1)
        # Flatten time and batch dims so the decoder sees a 2-D matrix.
        flat = rnn_out.view(seq_len * batch, rnn_out.size(2))
        decoded = self.decoder(flat)
        self.hidden = self.repackage_hidden(tuple(hidden))
        return decoded.view(seq_len, batch, decoded.size(1))
134