import torch.nn.functional as F
from torch import Tensor

from .module import Module


__all__ = ["ChannelShuffle"]


class ChannelShuffle(Module):
    r"""Rearrange the channels of a tensor by interleaving channel groups.

    The channel dimension of an input of shape :math:`(N, C, *)` is split into
    :math:`g` consecutive groups of :math:`\frac{C}{g}` channels each (so
    :math:`C` is expected to be divisible by :math:`g`). The channels are then
    reordered so that the :math:`k`-th channels of all groups end up next to
    each other. Conceptually, the channel axis is viewed as
    :math:`(g, \frac{C}{g})`, transposed to :math:`(\frac{C}{g}, g)` and
    flattened again. The output keeps the shape of the input.

    Args:
        groups (int): number of groups to divide the channels into.

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as the input)

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[13., 14.],
                  [15., 16.]]]])
        >>> output = channel_shuffle(input)
        >>> output
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[13., 14.],
                  [15., 16.]]]])
    """

    # TorchScript treats attributes listed in ``__constants__`` as constants.
    __constants__ = ["groups"]
    groups: int

    def __init__(self, groups: int) -> None:
        super().__init__()
        # Stored as-is; validation is left to ``F.channel_shuffle`` at call time.
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        # The permutation itself is delegated to the functional implementation.
        num_groups = self.groups
        return F.channel_shuffle(input, num_groups)

    def extra_repr(self) -> str:
        # Extra text shown inside this module's printed representation.
        return "groups={}".format(self.groups)