import warnings
from collections.abc import Iterable
from typing import (
    Any,
    Callable,
    List,
    NamedTuple,
    Optional,
    overload,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import Self

import torch
from torch import _VF, Tensor


__all__ = [
    "PackedSequence",
    "invert_permutation",
    "pack_padded_sequence",
    "pad_packed_sequence",
    "pad_sequence",
    "unpad_sequence",
    "pack_sequence",
    "unpack_sequence",
]

_T = TypeVar("_T")
_R = TypeVar("_R")


class PackedSequence_(NamedTuple):
    data: torch.Tensor
    batch_sizes: torch.Tensor
    sorted_indices: Optional[torch.Tensor]
    unsorted_indices: Optional[torch.Tensor]


# Apply `fn` to `optional` when it is not None; otherwise propagate None.
# (A small Optional-map helper, used below for the optional index tensors.)
def bind(optional: Optional[_T], fn: Callable[[_T], _R]) -> Optional[_R]:
    if optional is None:
        return None
    return fn(optional)


class PackedSequence(PackedSequence_):
    r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are meant
        to be instantiated by functions like :func:`pack_padded_sequence`.

        Batch sizes represent the number of elements at each sequence step in
        the batch, not the varying sequence lengths passed to
        :func:`pack_padded_sequence`.  For instance, given data ``abc`` and ``x``,
        the :class:`PackedSequence` would contain data ``axbc`` with
        ``batch_sizes=[2,1,1]``.
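
        For example, the packing above can be reproduced with
        :func:`pack_sequence` (a minimal sketch, with integer tensors standing
        in for the sequences ``abc`` and ``x``):

            >>> abc = torch.tensor([1, 2, 3])
            >>> x = torch.tensor([4])
            >>> torch.nn.utils.rnn.pack_sequence([abc, x])
            PackedSequence(data=tensor([1, 4, 2, 3]), batch_sizes=tensor([2, 1, 1]), sorted_indices=None, unsorted_indices=None)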

    Attributes:
        data (Tensor): Tensor containing packed sequence
        batch_sizes (Tensor): Tensor of integers holding
            information about the batch size at each sequence step
        sorted_indices (Tensor, optional): Tensor of integers holding how this
            :class:`PackedSequence` is constructed from sequences.
        unsorted_indices (Tensor, optional): Tensor of integers holding how to
            recover the original sequences with correct order.

    .. note::
        :attr:`data` can be on arbitrary device and of arbitrary dtype.
        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
        tensors on the same device as :attr:`data`.

        However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.

        This invariant is maintained throughout the :class:`PackedSequence` class,
        and by all functions that construct a :class:`PackedSequence` in PyTorch
        (i.e., they only pass in tensors conforming to this constraint).
    """

    def __new__(
        cls,
        data: Tensor,
        batch_sizes: Optional[Tensor] = None,
        sorted_indices: Optional[Tensor] = None,
        unsorted_indices: Optional[Tensor] = None,
    ) -> Self:
        return super().__new__(
            cls,
            *_packed_sequence_init_args(
                data, batch_sizes, sorted_indices, unsorted_indices
            ),
        )

    # NOTE [ device and dtype of a PackedSequence ]
    #
    # See the note above in the docstring (starting with ":attr:`data` can be on
    # arbitrary device...").
    def pin_memory(self) -> Self:
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        return type(self)(
            self.data.pin_memory(),
            self.batch_sizes,
            bind(self.sorted_indices, lambda t: t.pin_memory()),
            bind(self.unsorted_indices, lambda t: t.pin_memory()),
        )

    @overload
    def to(
        self,
        dtype: torch.dtype,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    @overload
    def to(
        self,
        device: Optional[Union[str, torch.device, int]] = ...,
        dtype: Optional[torch.dtype] = ...,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    @overload
    def to(
        self,
        other: Tensor,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    def to(self, *args: Any, **kwargs: Any) -> Self:
        r"""Perform dtype and/or device conversion on `self.data`.

        It has a similar signature to :meth:`torch.Tensor.to`, except optional
        arguments like `non_blocking` and `copy` should be passed as kwargs,
        not args, or they will not apply to the index tensors.

        .. note::

            If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
            and :class:`torch.device`, then ``self`` is returned.
            Otherwise, returns a copy with the desired configuration.
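
        Example (a minimal sketch; any :class:`PackedSequence` works here):

            >>> packed = torch.nn.utils.rnn.pack_sequence([torch.tensor([1.0, 2.0])])
            >>> packed.to(torch.float64).data.dtype
            torch.float64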
        """
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        data = self.data.to(*args, **kwargs)
        if data is self.data:
            return self
        else:
            # Do not forward device or dtype args/kwargs; the device is taken
            # from data.device instead.
            kwargs = dict(
                filter(lambda t: t[0] != "device" and t[0] != "dtype", kwargs.items())
            )
            sorted_indices = bind(
                self.sorted_indices, lambda t: t.to(data.device, **kwargs)
            )
            unsorted_indices = bind(
                self.unsorted_indices, lambda t: t.to(data.device, **kwargs)
            )
            return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)

    def cuda(self, *args: Any, **kwargs: Any) -> Self:
        # Tests to see if 'cuda' should be added to kwargs
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
            *args, **kwargs
        )
        if ex.is_cuda:
            return self.to(*args, **kwargs)
        kwargs["device"] = "cuda"
        return self.to(*args, **kwargs)

    def cpu(self, *args: Any, **kwargs: Any) -> Self:
        # Tests to see if 'cpu' should be added to kwargs
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
            *args, **kwargs
        )
        if ex.device.type == "cpu":
            return self.to(*args, **kwargs)
        kwargs["device"] = "cpu"
        return self.to(*args, **kwargs)

    def double(self) -> Self:
        return self.to(dtype=torch.double)

    def float(self) -> Self:
        return self.to(dtype=torch.float)

    def half(self) -> Self:
        return self.to(dtype=torch.half)

    def long(self) -> Self:
        return self.to(dtype=torch.long)

    def int(self) -> Self:
        return self.to(dtype=torch.int)

    def short(self) -> Self:
        return self.to(dtype=torch.short)

    def char(self) -> Self:
        return self.to(dtype=torch.int8)

    def byte(self) -> Self:
        return self.to(dtype=torch.uint8)

    @property
    def is_cuda(self) -> bool:
        r"""Return true if `self.data` is stored on a GPU."""
        return self.data.is_cuda

    def is_pinned(self) -> bool:
        r"""Return true if `self.data` is stored in pinned memory."""
        return self.data.is_pinned()


# TorchScript doesn't support constructors on named tuples, so we use this helper
# function to construct PackedSequence
def _packed_sequence_init_args(
    data: Tensor,
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    # NB: if unsorted_indices is provided, it should be the inverse permutation
    # to sorted_indices. Don't assert it here because the PackedSequence ctor
    # should only be used internally.

    if unsorted_indices is None:
        unsorted_indices = invert_permutation(sorted_indices)

    # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
    if batch_sizes is not None:
        # TODO: Re-enable this check (.type isn't supported in TorchScript)
        if batch_sizes.device.type != "cpu":
            raise ValueError(
                "batch_sizes should always be on CPU. "
                "Instances of PackedSequence should never be created manually. "
                "They should be instantiated by functions like pack_sequence "
                "and pack_padded_sequence in nn.utils.rnn. "
                "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence"
            )
        return data, batch_sizes, sorted_indices, unsorted_indices

    # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
    else:
        assert isinstance(data, (list, tuple)) and len(data) == 2
        return data[0], data[1], sorted_indices, unsorted_indices


def _packed_sequence_init(
    data: Tensor,
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> PackedSequence:
    data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
        data, batch_sizes, sorted_indices, unsorted_indices
    )
    return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)


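# `invert_permutation` returns the inverse of an index permutation, i.e.
# output[permutation[i]] = i for every position i. For example (a sketch):
#   invert_permutation(torch.tensor([2, 0, 1]))  # -> tensor([1, 2, 0])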
def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:
    if permutation is None:
        return None
    output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
    output.scatter_(
        0, permutation, torch.arange(0, permutation.numel(), device=permutation.device)
    )
    return output


def pack_padded_sequence(
    input: Tensor,
    lengths: Union[Tensor, List[int]],
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
    or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length
    of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions
    (including 0).

    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
    ``True``, the sequences should be sorted by length in decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
    one. `enforce_sorted = True` is only necessary for ONNX export.

    Note:
        This function accepts any input that has at least two dimensions. You
        can apply it to pack the labels, and use the output of the RNN with
        them to compute the loss directly. A Tensor can be retrieved from
        a :class:`PackedSequence` object by accessing its ``.data`` attribute.

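    Example (a sketch mirroring the round trip shown in :func:`pad_packed_sequence`):
        >>> from torch.nn.utils.rnn import pack_padded_sequence
        >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
        >>> lens = [2, 1, 3]
        >>> pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
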
    Args:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor or list(int)): list of sequence lengths of each batch
            element (must be on the CPU if provided as a tensor).
        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
            format, ``T x B x *`` otherwise.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    if not isinstance(lengths, torch.Tensor):
        if torch._C._get_tracing_state():
            warnings.warn(
                "pack_padded_sequence has been called with a Python list of "
                "sequence lengths. The tracer cannot track the data flow of Python "
                "values, and it will treat them as constants, likely rendering "
                "the trace incorrect for any other combination of lengths.",
                stacklevel=2,
            )
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")
    else:
        lengths = lengths.to(dtype=torch.int64)

    if enforce_sorted:
        sorted_indices = None
    else:
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        batch_dim = 0 if batch_first else 1
        input = input.index_select(batch_dim, sorted_indices)

    data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)


def pad_packed_sequence(
    sequence: PackedSequence,
    batch_first: bool = False,
    padding_value: float = 0.0,
    total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Pad a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Tensor's data will be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
    or ``B x T x *`` (if :attr:`batch_first` is ``True``), where ``T`` is the length of the longest
    sequence and ``B`` is the batch size.

    Example:
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
        >>> lens = [2, 1, 3]
        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        >>> packed
        PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
                       sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
        >>> seq_unpacked
        tensor([[1, 2, 0],
                [3, 0, 0],
                [4, 5, 6]])
        >>> lens_unpacked
        tensor([2, 1, 3])

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
        details.

    Args:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format, ``T x B x *`` otherwise.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements will be re-ordered as they were ordered originally when
        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
    """
    max_seq_length = sequence.batch_sizes.size(0)
    if total_length is not None:
        if total_length < max_seq_length:
            raise ValueError(
                "Expected total_length to be at least the length "
                "of the longest sequence in input, but got "
                f"total_length={total_length} and max sequence length being {max_seq_length}"
            )
        max_seq_length = total_length
    padded_output, lengths = _VF._pad_packed_sequence(
        sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length
    )
    unsorted_indices = sequence.unsorted_indices
    if unsorted_indices is not None:
        batch_dim = 0 if batch_first else 1
        return (
            padded_output.index_select(batch_dim, unsorted_indices),
            lengths[unsorted_indices.cpu()],
        )
    return padded_output, lengths


# NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable.
def pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    batch_first: bool = False,
    padding_value: float = 0.0,
    padding_side: str = "right",
) -> Tensor:
    r"""Pad a list of variable length Tensors with :attr:`padding_value`.

    ``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them
    to equal length. :attr:`sequences` can be a list of sequences with size ``L x *``,
    where ``L`` is the length of the sequence and ``*`` is any number of dimensions
    (including 0). If :attr:`batch_first` is ``False``, the output is of size
    ``T x B x *``, and ``B x T x *`` otherwise, where ``B`` is the batch size
    (the number of elements in :attr:`sequences`) and ``T`` is the length of the
    longest sequence.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])
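        >>> # With batch_first (and optionally left padding) only the layout
        >>> # changes, not the size (a sketch, reusing the tensors above):
        >>> pad_sequence([a, b, c], batch_first=True).size()
        torch.Size([3, 25, 300])
        >>> pad_sequence([a, b, c], batch_first=True, padding_side="left").size()
        torch.Size([3, 25, 300])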

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where ``T`` is the length of the longest sequence. This function assumes
        that the trailing dimensions and type of all the Tensors in ``sequences``
        are the same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format, ``T x B x *`` otherwise.
        padding_value (float, optional): value for padded elements. Default: 0.
        padding_side (str, optional): the side to pad the sequences on.
            Default: "right".

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """
    if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        # JIT doesn't support `Iterable`
        if not isinstance(sequences, Iterable):
            msg = (
                "pad_sequence: Expected iterable for input sequences, but got arg of type: "
                f"{type(sequences)}"
            )
            raise RuntimeError(msg)

        # In JIT context this leads to,
        # RuntimeError: cannot statically infer the expected size of a list in this context
        sequences = tuple(sequences)  # type: ignore[assignment]
    else:
        # For JIT, we only support Union[Tensor, Tuple[Tensor]]
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.unbind(0)  # type: ignore[assignment]

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    return torch._C._nn.pad_sequence(
        sequences, batch_first, padding_value, padding_side  # type: ignore[arg-type]
    )


def unpad_sequence(
    padded_sequences: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
) -> List[Tensor]:
    r"""Unpad a padded Tensor into a list of variable length Tensors.

    ``unpad_sequence`` unstacks a padded Tensor into a list of variable length Tensors.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> sequences = [a, b, c]
        >>> padded_sequences = pad_sequence(sequences)
        >>> lengths = torch.as_tensor([v.size(0) for v in sequences])
        >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
        >>> torch.allclose(sequences[0], unpadded_sequences[0])
        True
        >>> torch.allclose(sequences[1], unpadded_sequences[1])
        True
        >>> torch.allclose(sequences[2], unpadded_sequences[2])
        True

    Args:
        padded_sequences (Tensor): padded sequences.
        lengths (Tensor): length of original (unpadded) sequences.
        batch_first (bool, optional): whether the batch dimension comes first or not.
            Default: ``False``.

    Returns:
        a list of :class:`Tensor` objects
    """
    unpadded_sequences = []

    if not batch_first:
        # Note: transposes in place so that the batch dimension comes first.
        padded_sequences.transpose_(0, 1)

    max_length = padded_sequences.shape[1]
    idx = torch.arange(max_length, device=lengths.device)

    for seq, length in zip(padded_sequences, lengths):
        # Keep only the first `length` steps of each sequence.
        mask = idx < length
        unpacked_seq = seq[mask]
        unpadded_sequences.append(unpacked_seq)

    return unpadded_sequences


def pack_sequence(
    sequences: List[Tensor],
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a list of variable length Tensors.

    This is equivalent to calling ``pad_sequence`` followed by
    ``pack_padded_sequence``.

    ``sequences`` should be a list of Tensors of size ``L x *``, where ``L`` is
    the length of a sequence and ``*`` is any number of trailing dimensions,
    including zero.

    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
    is ``True``, the sequences should be sorted in the order of decreasing length.
    ``enforce_sorted = True`` is only necessary for ONNX export.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> pack_sequence([a, b, c])
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
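        >>> # Unsorted input requires enforce_sorted=False; the index tensors
        >>> # record how to restore the original order (a sketch):
        >>> pack_sequence([c, a, b], enforce_sorted=False)
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([1, 2, 0]), unsorted_indices=tensor([2, 0, 1]))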

    Args:
        sequences (list[Tensor]): A list of sequences of decreasing length.
        enforce_sorted (bool, optional): if ``True``, checks that the input
            contains sequences sorted by length in decreasing order. If
            ``False``, this condition is not checked. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    lengths = torch.as_tensor([v.size(0) for v in sequences])
    return pack_padded_sequence(
        pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted
    )


def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
    r"""Unpack a PackedSequence into a list of variable length Tensors.

    ``packed_sequences`` should be a :class:`PackedSequence` object.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> sequences = [a, b, c]
        >>> print(sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
        >>> packed_sequences = pack_sequence(sequences)
        >>> print(packed_sequences)
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
        >>> unpacked_sequences = unpack_sequence(packed_sequences)
        >>> print(unpacked_sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]

    Args:
        packed_sequences (PackedSequence): A PackedSequence object.

    Returns:
        a list of :class:`Tensor` objects
    """
    padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
    unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True)
    return unpacked_sequences