// aten/src/ATen/native/cpu/ChannelShuffleKernel.cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cpu/ChannelShuffleKernel.h>

#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

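// Channel shuffle on a contiguous (channels-first) input: the channel
// dimension c = groups * channels_per_group is viewed as
// [groups, channels_per_group] and the two axes are swapped, so output
// channel oc * groups + g is a copy of input channel
// g * channels_per_group + oc. Each output channel is one contiguous plane
// of image_size elements, copied below with vectorized loads/stores.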
template <typename scalar_t>
void cpu_channel_shuffle(
    TensorBase& output,
    const TensorBase& input,
    int64_t groups) {
  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t channels_per_group = channels / groups;
  int64_t image_size = input.numel() / nbatch / channels;

  // treat the input tensor as shape [n, g, oc, ...]
  // and the output tensor as shape [n, oc, g, ...]
  //
  // 3d, 4d, 5d: parallelize over the n, c dimensions
  using Vec = vec::Vectorized<scalar_t>;
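  // inner_size: image_size rounded down to a whole number of vector lanes;
  // the remaining tail elements are copied with the scalar loop below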
  int64_t inner_size = image_size - (image_size % Vec::size());
  at::parallel_for(0, nbatch * /* oc*g */ channels, 0, [&](int64_t begin, int64_t end) {
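    // decompose the flattened task index `begin` into start coordinates
    // (n, oc, g); data_index_step at the end of the loop body keeps them
    // in sync with i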
    int64_t n = 0;
    int64_t oc = 0;
    int64_t g = 0;
    data_index_init(begin, n, nbatch, oc, channels_per_group, g, groups);

    for (const auto i : c10::irange(begin, end)) {
      scalar_t* output_ptr = output_data + i * image_size;
      scalar_t* input_ptr = input_data + n * channels * image_size +
          g * channels_per_group * image_size + oc * image_size;

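      // copy one channel plane: vectorized over the first inner_size
      // elements, then a scalar loop for the tail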
      int64_t d = 0;
      for (; d < inner_size; d += Vec::size()) {
        Vec data_vec = Vec::loadu(input_ptr + d);
        data_vec.store(output_ptr + d);
      }
      for (; d < image_size; d++) {
        output_ptr[d] = input_ptr[d];
      }

      // move on to the next output index
      data_index_step(n, nbatch, oc, channels_per_group, g, groups);
    }
  });
}

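// Channels-last path: channels are the innermost, densest dimension, so for
// each of the nbatch * image_size pixels the shuffle is a small transpose of
// a [groups, channels_per_group] matrix into [channels_per_group, groups].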
template <typename scalar_t>
void cpu_channel_shuffle_cl(
    TensorBase& output,
    const TensorBase& input,
    int64_t groups) {
  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t channels_per_group = channels / groups;
  int64_t image_size = input.numel() / nbatch / channels;

  // 4d: parallelize over the n, h, w dimensions
  // 5d: parallelize over the n, d, h, w dimensions
  at::parallel_for(0, nbatch * image_size, 0, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      scalar_t* output_ptr = output_data + i * channels;
      scalar_t* input_ptr = input_data + i * channels;

      // transpose each channel lane:
      // from [groups, channels_per_group] to [channels_per_group, groups]
      utils::transpose(groups, channels_per_group, input_ptr, channels_per_group, output_ptr, groups);
    }
  });
}

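// Entry point behind the dispatch stub: pick the kernel by the input's
// suggested memory format, then dispatch on scalar type (all standard types
// plus Bool, BFloat16, Half, and complex are covered).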
void channel_shuffle_kernel_impl(
    TensorBase& output,
    const TensorBase& input,
    int64_t groups) {
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
          input.scalar_type(), "channel_shuffle", [&] {
        cpu_channel_shuffle<scalar_t>(output, input, groups);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast:
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
          input.scalar_type(), "channel_shuffle_cl", [&] {
        cpu_channel_shuffle_cl<scalar_t>(output, input, groups);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, ChannelsLast3d, Contiguous");
  }
}

} // anonymous namespace

REGISTER_DISPATCH(channel_shuffle_kernel, &channel_shuffle_kernel_impl);
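// Usage sketch (illustrative shapes; on CPU this stub is normally reached via
// the user-facing channel_shuffle op): for an input of shape [N, C, H, W] with
// C % groups == 0, output channel c reads input channel
// (c % groups) * (C / groups) + c / groups; e.g. with C = 4 and groups = 2,
// channels [0, 1, 2, 3] are reordered to [0, 2, 1, 3].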

} // at::native