# The model is from here:
# https://github.com/pytorch/examples/blob/master/word_language_model/model.py

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(
        self,
        rnn_type,
        ntoken,
        ninp,
        nhid,
        nlayers,
        dropout=0.5,
        tie_weights=False,
        batchsize=2,
    ):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ["LSTM", "GRU"]:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    """An invalid option for `--model` was supplied,
                    options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"""
                ) from None
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
        self.hidden = self.init_hidden(batchsize)

    @staticmethod
    def repackage_hidden(h):
        """Detach hidden states from their history."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple([RNNModel.repackage_hidden(v) for v in h])

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        self.hidden = RNNModel.repackage_hidden(hidden)
        return decoded.view(output.size(0), output.size(1), decoded.size(1))

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == "LSTM":
            return (
                weight.new(self.nlayers, bsz, self.nhid).zero_(),
                weight.new(self.nlayers, bsz, self.nhid).zero_(),
            )
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()


class RNNModelWithTensorHidden(RNNModel):
    """Supports GRU scripting."""

    @staticmethod
    def repackage_hidden(h):
        """Detach hidden states from their history."""
        return h.detach()

    def forward(self, input: Tensor, hidden: Tensor):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        self.hidden = RNNModelWithTensorHidden.repackage_hidden(hidden)
        return decoded.view(output.size(0), output.size(1), decoded.size(1))


class RNNModelWithTupleHidden(RNNModel):
    """Supports LSTM scripting."""

    @staticmethod
    def repackage_hidden(h: Tuple[Tensor, Tensor]):
        """Detach hidden states from their history."""
        return (h[0].detach(), h[1].detach())

    def forward(self, input: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        self.hidden = self.repackage_hidden(tuple(hidden))
        return decoded.view(output.size(0), output.size(1), decoded.size(1))
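

if __name__ == "__main__":
    # Usage sketch, not part of the upstream example: the two subclasses exist
    # so TorchScript sees one concrete hidden-state type per module (a single
    # Tensor for GRU, a (h, c) tuple for LSTM), whereas the base class mixes
    # both. The vocabulary size, dimensions, and input shape below are
    # arbitrary illustrative assumptions.
    ntoken, seq_len, bsz = 1000, 35, 2
    tokens = torch.randint(0, ntoken, (seq_len, bsz))  # (seq_len, batch)

    # GRU carries its hidden state as one tensor, matching the Tensor-typed
    # forward/repackage_hidden signatures above.
    gru = RNNModelWithTensorHidden(
        "GRU", ntoken, ninp=128, nhid=128, nlayers=2, batchsize=bsz
    )
    scripted_gru = torch.jit.script(gru)
    out = scripted_gru(tokens, scripted_gru.hidden)
    print(out.shape)  # expected: torch.Size([35, 2, 1000])

    # LSTM carries a (h, c) tuple, matching the Tuple-typed signatures.
    lstm = RNNModelWithTupleHidden(
        "LSTM", ntoken, ninp=128, nhid=128, nlayers=2, batchsize=bsz
    )
    scripted_lstm = torch.jit.script(lstm)
    out = scripted_lstm(tokens, scripted_lstm.hidden)
    print(out.shape)  # expected: torch.Size([35, 2, 1000])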