xref: /aosp_15_r20/external/pytorch/torch/nn/modules/channelshuffle.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch.nn.functional as F
2from torch import Tensor
3
4from .module import Module
5
6
# Names exported by a wildcard import of this module; the only public
# symbol defined here is the ChannelShuffle layer.
__all__ = ["ChannelShuffle"]
8
9
class ChannelShuffle(Module):
    r"""Interleave the channels of a tensor across ``groups`` groups.

    The channel dimension of an input of shape :math:`(N, C, *)` is viewed as
    :math:`g` groups of :math:`\frac{C}{g}` channels each. Those groups are
    then rearranged into the layout :math:`(N, \frac{C}{g}, g, *)` and
    flattened back. As a result, consecutive output channels are drawn from
    different groups. The output has exactly the same shape as the input.

    Because of the :math:`\frac{C}{g}` split, the number of channels ``C`` is
    expected to be divisible by ``groups``.

    Args:
        groups (int): how many groups the channel dimension is split into.

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same as the input)

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[13., 14.],
                  [15., 16.]]]])
        >>> output = channel_shuffle(input)
        >>> output
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[13., 14.],
                  [15., 16.]]]])
    """

    # Declaring ``groups`` as a constant lets TorchScript treat it as a
    # compile-time value rather than a mutable attribute.
    __constants__ = ["groups"]
    groups: int

    def __init__(self, groups: int) -> None:
        super().__init__()
        # Stored as given; the functional op consumes it on every forward.
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        # All of the reshaping work is delegated to the functional form.
        num_groups = self.groups
        shuffled = F.channel_shuffle(input, num_groups)
        return shuffled

    def extra_repr(self) -> str:
        # Shown inside the module's repr, e.g. ``ChannelShuffle(groups=2)``.
        return "groups={}".format(self.groups)
57