import warnings
from collections.abc import Iterable
from typing import (
    Any,
    Callable,
    List,
    NamedTuple,
    Optional,
    overload,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import Self

import torch
from torch import _VF, Tensor


__all__ = [
    "PackedSequence",
    "invert_permutation",
    "pack_padded_sequence",
    "pad_packed_sequence",
    "pad_sequence",
    "unpad_sequence",
    "pack_sequence",
    "unpack_sequence",
]

_T = TypeVar("_T")
_R = TypeVar("_R")


class PackedSequence_(NamedTuple):
    """Named-tuple storage backing :class:`PackedSequence`."""

    data: torch.Tensor
    batch_sizes: torch.Tensor
    sorted_indices: Optional[torch.Tensor]
    unsorted_indices: Optional[torch.Tensor]


def bind(optional: Optional[_T], fn: Callable[[_T], _R]) -> Optional[_R]:
    """Return ``fn(optional)``, or ``None`` when ``optional`` is ``None``."""
    if optional is None:
        return None
    return fn(optional)


class PackedSequence(PackedSequence_):
    r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are meant
        to be instantiated by functions like :func:`pack_padded_sequence`.

        Batch sizes represent the number of elements at each sequence step in
        the batch, not the varying sequence lengths passed to
        :func:`pack_padded_sequence`.  For instance, given data ``abc`` and ``x``
        the :class:`PackedSequence` would contain data ``axbc`` with
        ``batch_sizes=[2,1,1]``.

    Attributes:
        data (Tensor): Tensor containing packed sequence
        batch_sizes (Tensor): Tensor of integers holding
            information about the batch size at each sequence step
        sorted_indices (Tensor, optional): Tensor of integers holding how this
            :class:`PackedSequence` is constructed from sequences.
        unsorted_indices (Tensor, optional): Tensor of integers holding how
            to recover the original sequences in the correct order.

    .. note::
        :attr:`data` can be on arbitrary device and of arbitrary dtype.
        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
        tensors on the same device as :attr:`data`.

        However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.

        This invariant is maintained throughout :class:`PackedSequence` class,
        and all functions that construct a :class:`PackedSequence` in PyTorch
        (i.e., they only pass in tensors conforming to this constraint).
    """

    def __new__(
        cls,
        data: Tensor,
        batch_sizes: Optional[Tensor] = None,
        sorted_indices: Optional[Tensor] = None,
        unsorted_indices: Optional[Tensor] = None,
    ) -> Self:
        return super().__new__(
            cls,
            *_packed_sequence_init_args(
                data, batch_sizes, sorted_indices, unsorted_indices
            ),
        )

    # NOTE [ device and dtype of a PackedSequence ]
    #
    # See the note above in doc string (starting with ":attr:`data` can be on
    # arbitrary device...").
    def pin_memory(self) -> Self:
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        return type(self)(
            self.data.pin_memory(),
            self.batch_sizes,
            bind(self.sorted_indices, lambda t: t.pin_memory()),
            bind(self.unsorted_indices, lambda t: t.pin_memory()),
        )

    @overload
    def to(
        self,
        dtype: torch.dtype,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    @overload
    def to(
        self,
        device: Optional[Union[str, torch.device, int]] = ...,
        dtype: Optional[torch.dtype] = ...,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    @overload
    def to(
        self,
        other: Tensor,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    def to(self, *args: Any, **kwargs: Any) -> Self:
        r"""Perform dtype and/or device conversion on `self.data`.

        It has a similar signature to :meth:`torch.Tensor.to`, except optional
        arguments like `non_blocking` and `copy` should be passed as kwargs,
        not args, or they will not apply to the index tensors.

        .. note::

            If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
            and :class:`torch.device`, then ``self`` is returned.
            Otherwise, returns a copy with the desired configuration.
        """
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        data = self.data.to(*args, **kwargs)
        if data is self.data:
            return self
        else:
            # Does not forward device or dtype arg/kwargs, device is set from data.device
            kwargs = {k: v for k, v in kwargs.items() if k not in ("device", "dtype")}
            sorted_indices = bind(
                self.sorted_indices, lambda t: t.to(data.device, **kwargs)
            )
            unsorted_indices = bind(
                self.unsorted_indices, lambda t: t.to(data.device, **kwargs)
            )
            return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)

    def cuda(self, *args: Any, **kwargs: Any) -> Self:
        # Tests to see if 'cuda' should be added to kwargs
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
            *args, **kwargs
        )
        if ex.is_cuda:
            return self.to(*args, **kwargs)
        kwargs["device"] = "cuda"
        return self.to(*args, **kwargs)

    def cpu(self, *args: Any, **kwargs: Any) -> Self:
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
            *args, **kwargs
        )
        if ex.device.type == "cpu":
            return self.to(*args, **kwargs)
        kwargs["device"] = "cpu"
        return self.to(*args, **kwargs)

    def double(self) -> Self:
        return self.to(dtype=torch.double)

    def float(self) -> Self:
        return self.to(dtype=torch.float)

    def half(self) -> Self:
        return self.to(dtype=torch.half)

    def long(self) -> Self:
        return self.to(dtype=torch.long)

    def int(self) -> Self:
        return self.to(dtype=torch.int)

    def short(self) -> Self:
        return self.to(dtype=torch.short)

    def char(self) -> Self:
        return self.to(dtype=torch.int8)

    def byte(self) -> Self:
        return self.to(dtype=torch.uint8)

    @property
    def is_cuda(self) -> bool:
        r"""Return true if `self.data` is stored on a GPU."""
        return self.data.is_cuda

    def is_pinned(self) -> bool:
        r"""Return true if `self.data` is stored in pinned memory."""
        return self.data.is_pinned()


# TorchScript doesn't support constructors on named tuples, so we use this helper
# method to construct PackedSequence
def _packed_sequence_init_args(
    data: Tensor,
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    # NB: if unsorted_indices is provided, it should be the inverse permutation
    # to sorted_indices. Don't assert it here because the PackedSequence ctor
    # should only be used internally.

    if unsorted_indices is None:
        unsorted_indices = invert_permutation(sorted_indices)

    # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
    if batch_sizes is not None:
        # TODO: Re-enable this check (.type isn't supported in TorchScript)
        if batch_sizes.device.type != "cpu":
            raise ValueError(
                "batch_sizes should always be on CPU. "
                "Instances of PackedSequence should never be created manually. "
                "They should be instantiated by functions like pack_sequence "
                "and pack_padded_sequences in nn.utils.rnn. "
                "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence"
            )
        return data, batch_sizes, sorted_indices, unsorted_indices

    # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
    else:
        assert isinstance(data, (list, tuple)) and len(data) == 2
        return data[0], data[1], sorted_indices, unsorted_indices


def _packed_sequence_init(
    data: Tensor,
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> PackedSequence:
    data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
        data, batch_sizes, sorted_indices, unsorted_indices
    )
    return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)


def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:
    """Return the inverse of a 1-D permutation tensor.

    The result satisfies ``output[permutation[i]] == i``. Returns ``None`` when
    ``permutation`` is ``None``.
    """
    if permutation is None:
        return None
    output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
    output.scatter_(
        0, permutation, torch.arange(0, permutation.numel(), device=permutation.device)
    )
    return output


def pack_padded_sequence(
    input: Tensor,
    lengths: Union[Tensor, List[int]],
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
    or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length
    of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions
    (including 0).

    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
    ``True``, the sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
    one. `enforce_sorted = True` is only necessary for ONNX export.

    Note:
        This function accepts any input that has at least two dimensions. You
        can apply it to pack the labels, and use the output of the RNN with
        them to compute the loss directly. A Tensor can be retrieved from
        a :class:`PackedSequence` object by accessing its ``.data`` attribute.

    Args:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor or list(int)): list of sequence lengths of each batch
            element (must be on the CPU if provided as a tensor).
        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
            format, ``T x B x *`` otherwise.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    if not isinstance(lengths, torch.Tensor):
        if torch._C._get_tracing_state():
            warnings.warn(
                "pack_padded_sequence has been called with a Python list of "
                "sequence lengths. The tracer cannot track the data flow of Python "
                "values, and it will treat them as constants, likely rendering "
                "the trace incorrect for any other combination of lengths.",
                stacklevel=2,
            )
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")
    else:
        lengths = lengths.to(dtype=torch.int64)

    if enforce_sorted:
        sorted_indices = None
    else:
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        batch_dim = 0 if batch_first else 1
        input = input.index_select(batch_dim, sorted_indices)

    data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)


def pad_packed_sequence(
    sequence: PackedSequence,
    batch_first: bool = False,
    padding_value: float = 0.0,
    total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Pad a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Tensor's data will be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
    or ``B x T x *`` (if :attr:`batch_first` is ``True``), where ``T`` is the length of the longest
    sequence and ``B`` is the batch size.

    Example:
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
        >>> lens = [2, 1, 3]
        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        >>> packed
        PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
                       sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
        >>> seq_unpacked
        tensor([[1, 2, 0],
                [3, 0, 0],
                [4, 5, 6]])
        >>> lens_unpacked
        tensor([2, 1, 3])

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
        details.

    Args:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format, ``T x B x *`` otherwise.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements will be re-ordered as they were ordered originally when
        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
    """
    max_seq_length = sequence.batch_sizes.size(0)
    if total_length is not None:
        if total_length < max_seq_length:
            raise ValueError(
                "Expected total_length to be at least the length "
                "of the longest sequence in input, but got "
                f"total_length={total_length} and max sequence length being {max_seq_length}"
            )
        max_seq_length = total_length
    padded_output, lengths = _VF._pad_packed_sequence(
        sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length
    )
    unsorted_indices = sequence.unsorted_indices
    if unsorted_indices is not None:
        batch_dim = 0 if batch_first else 1
        return (
            padded_output.index_select(batch_dim, unsorted_indices),
            lengths[unsorted_indices.cpu()],
        )
    return padded_output, lengths


# NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable.
def pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    batch_first: bool = False,
    padding_value: float = 0.0,
    padding_side: str = "right",
) -> Tensor:
    r"""Pad a list of variable length Tensors with :attr:`padding_value`.

    ``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them
    to equal length. :attr:`sequences` can be list of sequences with size ``L x *``,
    where `L` is length of the sequence and ``*`` is any number of dimensions
    (including 0). If :attr:`batch_first` is ``False``, the output is of size
    ``T x B x *``, and ``B x T x *`` otherwise, where ``B`` is the batch size
    (the number of elements in :attr:`sequences`), ``T`` is the length of the longest
    sequence.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format, ``T x B x *`` otherwise.
        padding_value (float, optional): value for padded elements. Default: 0.
        padding_side (str, optional): the side to pad the sequences on.
            Default: "right".

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """
    if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        # JIT doesn't support `Iterable`
        if not isinstance(sequences, Iterable):
            msg = (
                "pad_sequence: Expected iterable for input sequences, but got arg of type: "
                f"{type(sequences)}"
            )
            raise RuntimeError(msg)

        # In JIT context this leads to,
        # RuntimeError: cannot statically infer the expected size of a list in this context
        sequences = tuple(sequences)  # type: ignore[assignment]
    else:
        # For JIT, we only support Union[Tensor, Tuple[Tensor]]
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.unbind(0)  # type: ignore[assignment]

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    return torch._C._nn.pad_sequence(
        sequences, batch_first, padding_value, padding_side  # type: ignore[arg-type]
    )


def unpad_sequence(
    padded_sequences: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
) -> List[Tensor]:
    r"""Unpad padded Tensor into a list of variable length Tensors.

    ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.
    The input :attr:`padded_sequences` is not modified.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> sequences = [a, b, c]
        >>> padded_sequences = pad_sequence(sequences)
        >>> lengths = torch.as_tensor([v.size(0) for v in sequences])
        >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
        >>> torch.allclose(sequences[0], unpadded_sequences[0])
        True
        >>> torch.allclose(sequences[1], unpadded_sequences[1])
        True
        >>> torch.allclose(sequences[2], unpadded_sequences[2])
        True

    Args:
        padded_sequences (Tensor): padded sequences.
        lengths (Tensor): length of original (unpadded) sequences.
        batch_first (bool, optional): whether batch dimension first or not. Default: False.

    Returns:
        a list of :class:`Tensor` objects
    """
    unpadded_sequences = []

    if not batch_first:
        # Out-of-place transpose: the previous in-place ``transpose_`` silently
        # swapped the first two dimensions of the caller's tensor.
        padded_sequences = padded_sequences.transpose(0, 1)

    max_length = padded_sequences.shape[1]
    idx = torch.arange(max_length, device=lengths.device)

    for seq, length in zip(padded_sequences, lengths):
        mask = idx < length
        unpacked_seq = seq[mask]
        unpadded_sequences.append(unpacked_seq)

    return unpadded_sequences


def pack_sequence(
    sequences: List[Tensor],
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a list of variable length Tensors.

    Equivalent to calling ``pad_sequence`` followed by ``pack_padded_sequence``.

    ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
    the length of a sequence and `*` is any number of trailing dimensions,
    including zero.

    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
    is ``True``, the sequences should be sorted in the order of decreasing length.
    ``enforce_sorted = True`` is only necessary for ONNX export.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> pack_sequence([a, b, c])
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

    Args:
        sequences (list[Tensor]): A list of sequences (of decreasing length when
            ``enforce_sorted`` is ``True``).
        enforce_sorted (bool, optional): if ``True``, checks that the input
            contains sequences sorted by length in a decreasing order. If
            ``False``, this condition is not checked and the sequences are
            sorted internally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    lengths = torch.as_tensor([v.size(0) for v in sequences])
    return pack_padded_sequence(
        pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted
    )


def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
    r"""Unpack PackedSequence into a list of variable length Tensors.

    ``packed_sequences`` should be a PackedSequence object.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> sequences = [a, b, c]
        >>> print(sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
        >>> packed_sequences = pack_sequence(sequences)
        >>> print(packed_sequences)
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
        >>> unpacked_sequences = unpack_sequence(packed_sequences)
        >>> print(unpacked_sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]

    Args:
        packed_sequences (PackedSequence): A PackedSequence object.

    Returns:
        a list of :class:`Tensor` objects
    """
    padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
    unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True)
    return unpacked_sequences