# mypy: allow-untyped-defs
from typing import Tuple, Union

from torch import Tensor
from torch.types import _size

from .module import Module


__all__ = ["Flatten", "Unflatten"]


class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor.

    For use with :class:`~nn.Sequential`, see :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
    """

    __constants__ = ["start_dim", "end_dim"]
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        """Store the first and last dimension of the range to flatten."""
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input.flatten(start_dim, end_dim)``."""
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        """Return the ``start_dim``/``end_dim`` settings as a string."""
        return f"start_dim={self.start_dim}, end_dim={self.end_dim}"


class Unflatten(Module):
    r"""
    Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input;  a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, (2, 5, 5))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """

    # Variable-length tuple of (name, size) pairs, one pair per new dimension.
    # The trailing ellipsis matters: without it the alias would describe a
    # tuple holding exactly one pair, which the namedshape example above
    # (three pairs) does not satisfy.
    NamedShape = Tuple[Tuple[str, int], ...]

    __constants__ = ["dim", "unflattened_size"]
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(
        self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]
    ) -> None:
        """Validate and store the dimension and its target shape.

        The kind of ``dim`` selects which validation ``unflattened_size``
        must pass: an ``int`` dim requires a sequence of ints, a ``str`` dim
        requires a tuple of tuples.

        Raises:
            TypeError: if ``dim`` is neither ``int`` nor ``str``, or if
                ``unflattened_size`` fails the matching validation.
        """
        super().__init__()

        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input) -> None:
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples.

        Only the container types are checked here; the contents of the inner
        tuples are not inspected.
        """
        if not isinstance(input, tuple):
            raise TypeError(
                "unflattened_size must be a tuple of tuples, "
                + f"but found type {type(input).__name__}"
            )
        for idx, elem in enumerate(input):
            if not isinstance(elem, tuple):
                raise TypeError(
                    "unflattened_size must be tuple of tuples, "
                    + f"but found element of type {type(elem).__name__} at pos {idx}"
                )

    def _require_tuple_int(self, input) -> None:
        """Raise ``TypeError`` unless ``input`` is a tuple or list whose elements are all ints.

        Lists are accepted as well as tuples, even though the error messages
        only mention tuples.
        """
        if not isinstance(input, (tuple, list)):
            raise TypeError(
                f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}"
            )
        for idx, elem in enumerate(input):
            if not isinstance(elem, int):
                raise TypeError(
                    "unflattened_size must be tuple of ints, "
                    + f"but found element of type {type(elem).__name__} at pos {idx}"
                )

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input.unflatten(dim, unflattened_size)``."""
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        """Return the ``dim``/``unflattened_size`` settings as a string."""
        return f"dim={self.dim}, unflattened_size={self.unflattened_size}"