#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/div_rtn.h>

#include <ATen/cuda/CUDAContext.h>

#include <ATen/native/cuda/im2col.cuh>
#include <ATen/native/im2col_shape_check.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/col2im_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/im2col_native.h>
#endif

namespace at::native {
namespace {

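// col2im scatters a column buffer of shape [N, C * kH * kW, L] back into an
// image of shape [N, C, oH, oW], summing values wherever sliding windows
// overlap (so col2im(im2col(x)) scales each pixel by its coverage count).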
void col2im_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  TensorArg input_arg{input_, "input", 1};
  TensorArg output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  TORCH_CHECK(
      output_size.size() == 2,
      "Expected output_size to have 2 elements, but got size ",
      output_size.size());

  TORCH_CHECK(
      kernel_size.size() == 2,
      "Expected kernel_size to have 2 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "Expected dilation to have 2 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "Expected padding to have 2 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "Expected stride to have 2 elements, but got size ",
      stride.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];
  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];

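  // Validate the full geometry up front; the undefined Tensor() fills the
  // shared shape check's grad_output slot, which is unused here.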
  col2im_shape_check(
      input_,
      Tensor(),
      output_height,
      output_width,
      kernel_height,
      kernel_width,
      dilation_height,
      dilation_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width);

  Tensor input = input_.contiguous();

  bool batched_input = true;
  if (input.dim() == 2) {
    // Add a leading batch dimension so the kernel always sees 3D input.
    batched_input = false;
    input = input.view({1, input.size(0), input.size(1)});
  }

  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
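  // im2col packs kernel_height * kernel_width values per output channel into
  // each input plane, so the channel count is recovered by division.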
  int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height);
  int64_t input_batch_stride = input.stride(0);

  output.resize_({batch_size, n_output_plane, output_height, output_width});
  int64_t output_batch_stride = output.stride(0);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
      input.scalar_type(), "col2im_out_cuda", [&] {
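    // Number of sliding-window positions per spatial dim (the standard conv
    // output size): floor((out + 2*pad - dilation*(kernel-1) - 1) / stride) + 1.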
    int64_t height_col = (output_height + 2 * pad_height -
                          (dilation_height * (kernel_height - 1) + 1)) /
            stride_height +
        1;
    int64_t width_col = (output_width + 2 * pad_width -
                         (dilation_width * (kernel_width - 1) + 1)) /
            stride_width +
        1;

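    // Scatter-add every column entry back into its source pixel, across the
    // whole batch, on the current CUDA stream.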
    col2im_batched(
        at::cuda::getCurrentCUDAStream(),
        input.const_data_ptr<scalar_t>(),
        input_batch_stride,
        batch_size,
        n_output_plane,
        output_height,
        output_width,
        height_col,
        width_col,
        kernel_height,
        kernel_width,
        pad_height,
        pad_width,
        stride_height,
        stride_width,
        dilation_height,
        dilation_width,
        output.mutable_data_ptr<scalar_t>(),
        output_batch_stride);

    if (!batched_input) {
      output.resize_({n_output_plane, output_height, output_width});
    }
  });
}

} // namespace

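// out= variant: writes into (and resizes) the caller-provided output tensor.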
Tensor& col2im_out_cuda(const Tensor& input,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    Tensor& output) {
  col2im_out_cuda_template(
      output, input, output_size, kernel_size, dilation, padding, stride);
  return output;
}

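// Functional variant: allocates the result, then defers to the out= path.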
Tensor col2im_cuda(
    const Tensor& input,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  col2im_out_cuda_template(
      output, input, output_size, kernel_size, dilation, padding, stride);
  return output;
}

} // namespace at::native
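
// Usage sketch (assumptions: a CUDA float tensor, and the public
// at::im2col/at::col2im entry points, which dispatch to the kernels above
// for CUDA inputs):
//   auto x    = at::randn({1, 3, 8, 8}, at::kCUDA);
//   auto cols = at::im2col(x, /*kernel_size=*/{3, 3}, /*dilation=*/{1, 1},
//                          /*padding=*/{1, 1}, /*stride=*/{1, 1});
//   auto y    = at::col2im(cols, /*output_size=*/{8, 8}, {3, 3}, {1, 1},
//                          {1, 1}, {1, 1});
//   // y equals x scaled elementwise by each pixel's window-coverage count.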