import torch.nn.functional as F
from torch import Tensor

from .module import Module


__all__ = ["PixelShuffle", "PixelUnshuffle"]


class PixelShuffle(Module):
    r"""Move channel data into the spatial dimensions by an upscaling factor.

    Takes a tensor of shape :math:`(*, C \times r^2, H, W)` and rearranges its
    elements into a tensor of shape :math:`(*, C, H \times r, W \times r)`,
    with r being the upscale factor.

    The operation is useful for implementing efficient sub-pixel convolution
    with a stride of :math:`1/r`.

    For more details, see the paper
    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
    by Shi et al. (2016).

    Args:
        upscale_factor (int): factor by which the spatial resolution is increased

    Shape:
        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where

    .. math::
        C_{out} = C_{in} \div \text{upscale\_factor}^2

    .. math::
        H_{out} = H_{in} \times \text{upscale\_factor}

    .. math::
        W_{out} = W_{in} \times \text{upscale\_factor}

    Examples::

        >>> pixel_shuffle = nn.PixelShuffle(3)
        >>> input = torch.randn(1, 9, 4, 4)
        >>> output = pixel_shuffle(input)
        >>> print(output.size())
        torch.Size([1, 1, 12, 12])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    upscale_factor: int
    __constants__ = ["upscale_factor"]

    def __init__(self, upscale_factor: int) -> None:
        """Store ``upscale_factor`` on the module after base-class initialization."""
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.pixel_shuffle`` applied to ``input`` with the stored factor."""
        factor = self.upscale_factor
        return F.pixel_shuffle(input, factor)

    def extra_repr(self) -> str:
        """Return the ``upscale_factor=<value>`` description of this module."""
        return f"upscale_factor={self.upscale_factor}"


class PixelUnshuffle(Module):
    r"""Undo the PixelShuffle rearrangement.

    Inverts the :class:`~torch.nn.PixelShuffle` operation: elements of a tensor
    of shape :math:`(*, C, H \times r, W \times r)` are rearranged into a tensor
    of shape :math:`(*, C \times r^2, H, W)`, with r being the downscale factor.

    For more details, see the paper
    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
    by Shi et al. (2016).

    Args:
        downscale_factor (int): factor by which the spatial resolution is decreased

    Shape:
        - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
        - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where

    .. math::
        C_{out} = C_{in} \times \text{downscale\_factor}^2

    .. math::
        H_{out} = H_{in} \div \text{downscale\_factor}

    .. math::
        W_{out} = W_{in} \div \text{downscale\_factor}

    Examples::

        >>> pixel_unshuffle = nn.PixelUnshuffle(3)
        >>> input = torch.randn(1, 1, 12, 12)
        >>> output = pixel_unshuffle(input)
        >>> print(output.size())
        torch.Size([1, 9, 4, 4])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    downscale_factor: int
    __constants__ = ["downscale_factor"]

    def __init__(self, downscale_factor: int) -> None:
        """Store ``downscale_factor`` on the module after base-class initialization."""
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.pixel_unshuffle`` applied to ``input`` with the stored factor."""
        factor = self.downscale_factor
        return F.pixel_unshuffle(input, factor)

    def extra_repr(self) -> str:
        """Return the ``downscale_factor=<value>`` description of this module."""
        return f"downscale_factor={self.downscale_factor}"