xref: /aosp_15_r20/external/pytorch/torch/nn/modules/pixelshuffle.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch.nn.functional as F
2from torch import Tensor
3
4from .module import Module
5
6
7__all__ = ["PixelShuffle", "PixelUnshuffle"]
8
9
class PixelShuffle(Module):
    r"""Rearrange elements in a tensor according to an upscaling factor.

    Moves values out of the channel dimension and into the two trailing
    spatial dimensions: an input of shape :math:`(*, C \times r^2, H, W)`
    becomes an output of shape :math:`(*, C, H \times r, W \times r)`, with
    ``r`` the upscale factor. The input's channel count is therefore
    expected to be a multiple of :math:`r^2`.

    The operation is the building block of efficient sub-pixel convolution
    with a stride of :math:`1/r`, introduced by Shi et al. (2016) in
    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_.

    Args:
        upscale_factor (int): factor by which the spatial resolution grows

    Shape:
        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where

    .. math::
        C_{out} = C_{in} \div \text{upscale\_factor}^2

    .. math::
        H_{out} = H_{in} \times \text{upscale\_factor}

    .. math::
        W_{out} = W_{in} \times \text{upscale\_factor}

    Examples::

        >>> pixel_shuffle = nn.PixelShuffle(3)
        >>> input = torch.randn(1, 9, 4, 4)
        >>> output = pixel_shuffle(input)
        >>> print(output.size())
        torch.Size([1, 1, 12, 12])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    upscale_factor: int
    __constants__ = ["upscale_factor"]

    def __init__(self, upscale_factor: int) -> None:
        super().__init__()
        # Attribute name must match the entry listed in ``__constants__``.
        self.upscale_factor = upscale_factor

    def forward(self, input: Tensor) -> Tensor:
        """Apply the pixel shuffle to ``input`` using the stored factor."""
        factor = self.upscale_factor
        # All of the rearrangement work lives in the functional form.
        return F.pixel_shuffle(input, factor)

    def extra_repr(self) -> str:
        """Describe this module's configuration for its printed form."""
        return "upscale_factor={}".format(self.upscale_factor)
63
64
class PixelUnshuffle(Module):
    r"""Reverse the PixelShuffle operation.

    Undoes :class:`~torch.nn.PixelShuffle` by folding blocks of the two
    trailing spatial dimensions back into the channel dimension: an input of
    shape :math:`(*, C, H \times r, W \times r)` becomes an output of shape
    :math:`(*, C \times r^2, H, W)`, with ``r`` the downscale factor. Both
    spatial sizes of the input are therefore expected to be multiples of ``r``.

    Background on the sub-pixel rearrangement that this operation inverts is
    given by Shi et al. (2016) in
    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_.

    Args:
        downscale_factor (int): factor by which the spatial resolution shrinks

    Shape:
        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where

    .. math::
        C_{out} = C_{in} \times \text{downscale\_factor}^2

    .. math::
        H_{out} = H_{in} \div \text{downscale\_factor}

    .. math::
        W_{out} = W_{in} \div \text{downscale\_factor}

    Examples::

        >>> pixel_unshuffle = nn.PixelUnshuffle(3)
        >>> input = torch.randn(1, 1, 12, 12)
        >>> output = pixel_unshuffle(input)
        >>> print(output.size())
        torch.Size([1, 9, 4, 4])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    downscale_factor: int
    __constants__ = ["downscale_factor"]

    def __init__(self, downscale_factor: int) -> None:
        super().__init__()
        # Attribute name must match the entry listed in ``__constants__``.
        self.downscale_factor = downscale_factor

    def forward(self, input: Tensor) -> Tensor:
        """Apply the pixel unshuffle to ``input`` using the stored factor."""
        factor = self.downscale_factor
        # All of the rearrangement work lives in the functional form.
        return F.pixel_unshuffle(input, factor)

    def extra_repr(self) -> str:
        """Describe this module's configuration for its printed form."""
        return "downscale_factor={}".format(self.downscale_factor)
116