xref: /aosp_15_r20/external/pytorch/torch/nn/modules/flatten.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Tuple, Union
3
4from torch import Tensor
5from torch.types import _size
6
7from .module import Module
8
9
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["Flatten", "Unflatten"]
11
12
class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor.

    For use with :class:`~nn.Sequential`, see :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
    """

    # Attribute names declared as module constants (the TorchScript
    # ``__constants__`` convention); both are plain ints set in ``__init__``.
    __constants__ = ["start_dim", "end_dim"]
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        # All shape handling (including negative dims such as the default
        # ``end_dim=-1``) is delegated to ``Tensor.flatten``.
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        # Shown inside the module's repr, e.g. ``Flatten(start_dim=1, end_dim=-1)``.
        return f"start_dim={self.start_dim}, end_dim={self.end_dim}"
57
58
class Unflatten(Module):
    r"""
    Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input;  a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Raises:
        TypeError: if :attr:`dim` is neither ``int`` nor ``str``, or if :attr:`unflattened_size`
            is not of the form required for that kind of :attr:`dim`.

    Examples::
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, (2, 5, 5))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, torch.Size([2, 5, 5]))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """

    # Variable-length tuple of ``(name, size)`` pairs. The trailing ``...`` is
    # required: ``Tuple[Tuple[str, int]]`` would describe a tuple holding exactly
    # one pair, which contradicts multi-pair usage such as the example above.
    NamedShape = Tuple[Tuple[str, int], ...]

    # Attribute names declared as module constants (the TorchScript
    # ``__constants__`` convention).
    __constants__ = ["dim", "unflattened_size"]
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(
        self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]
    ) -> None:
        super().__init__()

        # The kind of ``dim`` selects which shape format is valid: integer dims
        # take a sequence of ints, named (str) dims take a tuple of tuples.
        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        # Stored as given (a list stays a list); no copy or conversion is made.
        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input) -> None:
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples.

        Only the container types are checked here; the contents of each inner
        tuple are left for ``Tensor.unflatten`` to validate.
        """
        if isinstance(input, tuple):
            for idx, elem in enumerate(input):
                if not isinstance(elem, tuple):
                    raise TypeError(
                        "unflattened_size must be tuple of tuples, "
                        + f"but found element of type {type(elem).__name__} at pos {idx}"
                    )
            return
        raise TypeError(
            "unflattened_size must be a tuple of tuples, "
            + f"but found type {type(input).__name__}"
        )

    def _require_tuple_int(self, input) -> None:
        """Raise ``TypeError`` unless ``input`` is a tuple or list of ints.

        ``torch.Size`` passes the container check because it subclasses
        ``tuple``. Note that ``bool`` elements also pass, since ``bool``
        subclasses ``int``.
        """
        if isinstance(input, (tuple, list)):
            for idx, elem in enumerate(input):
                if not isinstance(elem, int):
                    raise TypeError(
                        "unflattened_size must be tuple of ints, "
                        + f"but found element of type {type(elem).__name__} at pos {idx}"
                    )
            return
        raise TypeError(
            f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}"
        )

    def forward(self, input: Tensor) -> Tensor:
        # Delegates to ``Tensor.unflatten``, which checks that the product of
        # the new sizes matches the size of ``dim``.
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        # Shown inside the module's repr, e.g. ``Unflatten(dim=1, unflattened_size=(2, 25))``.
        return f"dim={self.dim}, unflattened_size={self.unflattened_size}"
159