"""Wrappers that let a recurrent model be driven with padded tensors.

Each wrapper packs a padded input batch together with its sequence lengths,
runs the wrapped model on the packed sequence, and pads the model's first
output back into a dense tensor.
"""

from torch import nn
from torch.nn.utils import rnn as rnn_utils


class RnnModelWithPackedSequence(nn.Module):
    """Pack the input, run ``model`` with optional extra arguments, unpack.

    Args:
        model: Callable module invoked as ``model(packed_input, *extra)``.
            Its result is indexed, and element 0 is handed to
            ``pad_packed_sequence``, so it must be a packed sequence.
        batch_first: Forwarded as ``batch_first`` to both
            ``pack_padded_sequence`` and ``pad_packed_sequence``.
    """

    def __init__(self, model, batch_first):
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, *args):
        """Run the wrapped model on ``input`` packed with the given lengths.

        Args:
            input: Padded input batch.
            *args: Zero or more extra positional arguments for the wrapped
                model, followed by ``seq_lengths`` as the LAST positional
                argument.

        Returns:
            A tuple whose first element is the padded output and whose
            remaining elements are the wrapped model's other outputs,
            unchanged.

        Raises:
            TypeError: If ``seq_lengths`` is not supplied.
        """
        if not args:
            # Without this guard the slice below fails with an opaque
            # "tuple index out of range".
            raise TypeError(
                "forward() missing required trailing argument: 'seq_lengths'"
            )
        # The lengths always travel last; everything before them belongs to
        # the wrapped model.
        extra, seq_lengths = args[:-1], args[-1]
        # NOTE: pack_padded_sequence is called with its default
        # enforce_sorted setting, exactly as before.
        packed = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed, *extra)
        padded, _ = rnn_utils.pad_packed_sequence(
            outputs[0], batch_first=self.batch_first
        )
        return (padded, *outputs[1:])


class RnnModelWithPackedSequenceWithoutState(nn.Module):
    """Pack the input, run ``model`` without an initial state, unpack.

    Args:
        model: Callable module invoked as ``model(packed_input)``. Element 0
            of its result is handed to ``pad_packed_sequence``.
        batch_first: Forwarded as ``batch_first`` to both
            ``pack_padded_sequence`` and ``pad_packed_sequence``.
    """

    def __init__(self, model, batch_first):
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, seq_lengths):
        """Run the wrapped model on ``input`` packed with ``seq_lengths``.

        Args:
            input: Padded input batch.
            seq_lengths: Lengths of the sequences in ``input``.

        Returns:
            A list whose first element is the padded output and whose
            remaining elements are the wrapped model's other outputs,
            unchanged.
        """
        packed = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed)
        padded, _ = rnn_utils.pad_packed_sequence(
            outputs[0], batch_first=self.batch_first
        )
        return [padded, *outputs[1:]]


class RnnModelWithPackedSequenceWithState(nn.Module):
    """Pack the input, run ``model`` with an initial state, unpack.

    Args:
        model: Callable module invoked as ``model(packed_input, hx)``.
            Element 0 of its result is handed to ``pad_packed_sequence``.
        batch_first: Forwarded as ``batch_first`` to both
            ``pack_padded_sequence`` and ``pad_packed_sequence``.
    """

    def __init__(self, model, batch_first):
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, hx, seq_lengths):
        """Run the wrapped model on packed ``input`` starting from ``hx``.

        Args:
            input: Padded input batch.
            hx: Initial state, passed through to the wrapped model as its
                second positional argument.
            seq_lengths: Lengths of the sequences in ``input``.

        Returns:
            A list whose first element is the padded output and whose
            remaining elements are the wrapped model's other outputs,
            unchanged.
        """
        packed = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed, hx)
        padded, _ = rnn_utils.pad_packed_sequence(
            outputs[0], batch_first=self.batch_first
        )
        return [padded, *outputs[1:]]