xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/rnn_model_with_packed_sequence.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from torch import nn
2from torch.nn.utils import rnn as rnn_utils
3
4
class RnnModelWithPackedSequence(nn.Module):
    """Wrap a recurrent module so it can be fed padded input through packing.

    ``forward`` packs the padded ``input`` using the LAST positional argument
    as the per-sequence lengths, calls the wrapped model on the packed input
    (plus any extra positional arguments in between), then pads the model's
    first output back into a dense tensor.
    """

    def __init__(self, model: nn.Module, batch_first: bool):
        """
        Args:
            model: Module invoked as ``model(packed_input, *extra_args)``.
                Presumably an ``nn.RNN`` / ``nn.GRU`` / ``nn.LSTM`` whose first
                output is a ``PackedSequence`` -- TODO confirm against callers.
            batch_first: Layout flag forwarded to both the pack and pad calls.
        """
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, *args):
        """Run the wrapped model on a packed view of ``input``.

        Args:
            input: Padded input sequences.
            *args: Zero or more extra arguments for the model (e.g. an initial
                hidden state), followed by the sequence lengths as the final
                positional argument.

        Returns:
            A tuple ``(padded_output, *remaining_model_outputs)``. The lengths
            returned by ``pad_packed_sequence`` are discarded.
        """
        # The trailing positional argument is the lengths; everything before it
        # is passed straight through to the model.
        model_args, seq_lengths = args[:-1], args[-1]
        # enforce_sorted is left at torch's default (True); per the torch docs,
        # seq_lengths must therefore be sorted in decreasing order.
        packed = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed, *model_args)
        first, rest = outputs[0], outputs[1:]
        # total_length is not passed, so (per torch docs) the padded time
        # dimension equals the longest length, not necessarily input's.
        padded, _ = rnn_utils.pad_packed_sequence(
            first, batch_first=self.batch_first
        )
        return (padded, *rest)
18
19
class RnnModelWithPackedSequenceWithoutState(nn.Module):
    """Wrap a recurrent module that is called without an explicit initial state.

    ``forward`` packs the padded ``input`` with ``seq_lengths``, calls the
    wrapped model on the packed input alone, and pads the model's first output
    back into a dense tensor.
    """

    def __init__(self, model: nn.Module, batch_first: bool):
        """
        Args:
            model: Module invoked as ``model(packed_input)``. Presumably an
                ``nn.RNN`` / ``nn.GRU`` / ``nn.LSTM`` whose first output is a
                ``PackedSequence`` -- TODO confirm against callers.
            batch_first: Layout flag forwarded to both the pack and pad calls.
        """
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, seq_lengths):
        """Run the wrapped model on a packed view of ``input``.

        Args:
            input: Padded input sequences.
            seq_lengths: Per-sequence lengths. ``enforce_sorted`` is left at
                torch's default (True), so per the torch docs these must be in
                decreasing order.

        Returns:
            A list ``[padded_output, *remaining_model_outputs]``. The lengths
            returned by ``pad_packed_sequence`` are discarded.
        """
        packed = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed)
        first, rest = outputs[0], outputs[1:]
        padded, _ = rnn_utils.pad_packed_sequence(
            first, batch_first=self.batch_first
        )
        return [padded, *rest]
32
33
class RnnModelWithPackedSequenceWithState(nn.Module):
    """Wrap a recurrent module that is called with an explicit initial state.

    ``forward`` packs the padded ``input`` with ``seq_lengths``, calls the
    wrapped model on the packed input and ``hx``, and pads the model's first
    output back into a dense tensor.
    """

    def __init__(self, model: nn.Module, batch_first: bool):
        """
        Args:
            model: Module invoked as ``model(packed_input, hx)``. Presumably an
                ``nn.RNN`` / ``nn.GRU`` / ``nn.LSTM`` whose first output is a
                ``PackedSequence`` -- TODO confirm against callers.
            batch_first: Layout flag forwarded to both the pack and pad calls.
        """
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, hx, seq_lengths):
        """Run the wrapped model on a packed view of ``input`` starting from ``hx``.

        Args:
            input: Padded input sequences.
            hx: Initial state, passed to the model unchanged (its structure is
                whatever the wrapped model expects, e.g. ``(h0, c0)`` for LSTM).
            seq_lengths: Per-sequence lengths. ``enforce_sorted`` is left at
                torch's default (True), so per the torch docs these must be in
                decreasing order.

        Returns:
            A list ``[padded_output, *remaining_model_outputs]``. The lengths
            returned by ``pad_packed_sequence`` are discarded.
        """
        packed = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed, hx)
        first, rest = outputs[0], outputs[1:]
        padded, _ = rnn_utils.pad_packed_sequence(
            first, batch_first=self.batch_first
        )
        return [padded, *rest]
46