#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cpu/ChannelShuffleKernel.h>

#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

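// Channel shuffle: with channels = groups * channels_per_group, input channel
// (g * channels_per_group + oc) is moved to output channel (oc * groups + g).
//
// Worked example with channels = 6, groups = 2 (channels_per_group = 3):
//   input channel order  : 0 1 2 3 4 5
//   output channel order : 0 3 1 4 2 5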
template <typename scalar_t>
void cpu_channel_shuffle(
    TensorBase& output,
    const TensorBase& input,
    int64_t groups) {
  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t channels_per_group = channels / groups;
  int64_t image_size = input.numel() / nbatch / channels;

  // treat input tensor as shape of [n, g, oc, ...]
  // output tensor as shape of [n, oc, g, ...]
  //
  // 3d, 4d, 5d: parallel on dimension of n, c
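  //
  // each output index i in [0, nbatch * channels) decomposes into (n, oc, g);
  // the matching input plane starts at
  //   input_data + n * channels * image_size + (g * channels_per_group + oc) * image_size
  // and is copied with vectorized loads/stores plus a scalar tail.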
  using Vec = vec::Vectorized<scalar_t>;
  int64_t inner_size = image_size - (image_size % Vec::size());
  at::parallel_for(0, nbatch * /* oc*g */ channels, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oc = 0;
    int64_t g = 0;
    data_index_init(begin, n, nbatch, oc, channels_per_group, g, groups);

    for (const auto i : c10::irange(begin, end)) {
      scalar_t* output_ptr = output_data + i * image_size;
      scalar_t* input_ptr = input_data + n * channels * image_size +
          g * channels_per_group * image_size + oc * image_size;

      int64_t d = 0;
      for (; d < inner_size; d += Vec::size()) {
        Vec data_vec = Vec::loadu(input_ptr + d);
        data_vec.store(output_ptr + d);
      }
      for (; d < image_size; d++) {
        output_ptr[d] = input_ptr[d];
      }

      // move on to next output index
      data_index_step(n, nbatch, oc, channels_per_group, g, groups);
    }
  });
}

template <typename scalar_t>
void cpu_channel_shuffle_cl(
    TensorBase& output,
    const TensorBase& input,
    int64_t groups) {
  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t channels_per_group = channels / groups;
  int64_t image_size = input.numel() / nbatch / channels;

  // 4d: parallel on dimension of n, h, w
  // 5d: parallel on dimension of n, d, h, w
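  //
  // in channels-last layout the channels of each spatial position form a contiguous
  // vector of length `channels`, so each of the nbatch * image_size positions can be
  // shuffled independently.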
  at::parallel_for(0, nbatch * image_size, 0, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      scalar_t* output_ptr = output_data + i * channels;
      scalar_t* input_ptr = input_data + i * channels;

      // transpose each channel lane:
      // from [groups, channels_per_group] to [channels_per_group, groups]
      utils::transpose(groups, channels_per_group, input_ptr, channels_per_group, output_ptr, groups);
    }
  });
}

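// channel_shuffle_kernel_impl: selects the kernel based on the suggested memory format
// (contiguous vs. channels-last) and dispatches over all supported dtypes, including
// bool, BFloat16, Half and complex.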
void channel_shuffle_kernel_impl(
    TensorBase& output,
    const TensorBase& input,
    int64_t groups) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
          input.scalar_type(), "channel_shuffle", [&] {
        cpu_channel_shuffle<scalar_t>(output, input, groups);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast:
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
          input.scalar_type(), "channel_shuffle_cl", [&] {
        cpu_channel_shuffle_cl<scalar_t>(output, input, groups);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, ChannelsLast3d, Contiguous");
  }
}

} // anonymous namespace

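// Registered with the channel_shuffle_kernel dispatch stub; the operator-level
// entry points (e.g. torch.channel_shuffle / nn.ChannelShuffle) are expected to
// route to this kernel for CPU tensors.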
REGISTER_DISPATCH(channel_shuffle_kernel, &channel_shuffle_kernel_impl);

} // namespace at::native