xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Im2Col.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/AccumulateType.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorUtils.h>
6 #include <ATen/Utils.h>
7 #include <ATen/div_rtn.h>
8 
9 #include <ATen/cuda/CUDAContext.h>
10 
11 #include <ATen/native/cuda/im2col.cuh>
12 #include <ATen/native/im2col_shape_check.h>
13 
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #include <ATen/NativeFunctions.h>
17 #else
18 #include <ATen/ops/empty_like.h>
19 #include <ATen/ops/col2im_native.h>
20 #include <ATen/ops/im2col_native.h>
21 #endif
22 
23 namespace at::native {
24 namespace {
25 
im2col_out_cuda_template(Tensor & output,const Tensor & input_,IntArrayRef kernel_size,IntArrayRef dilation,IntArrayRef padding,IntArrayRef stride)26 static void im2col_out_cuda_template(
27     Tensor& output,
28     const Tensor& input_,
29     IntArrayRef kernel_size,
30     IntArrayRef dilation,
31     IntArrayRef padding,
32     IntArrayRef stride) {
33   TORCH_CHECK(
34       kernel_size.size() == 2,
35       "It is expected kernel_size equals to 2, but got size ",
36       kernel_size.size());
37 
38   TORCH_CHECK(
39       dilation.size() == 2,
40       "It is expected dilation equals to 2, but got size ",
41       dilation.size());
42 
43   TORCH_CHECK(
44       padding.size() == 2,
45       "It is expected padding equals to 2, but got size ",
46       padding.size());
47 
48   TORCH_CHECK(
49       stride.size() == 2,
50       "It is expected stride equals to 2, but got size ",
51       stride.size());
52 
53   int64_t kernel_height = kernel_size[0];
54   int64_t kernel_width = kernel_size[1];
55   int64_t dilation_height = dilation[0];
56   int64_t dilation_width = dilation[1];
57   int64_t pad_height = padding[0];
58   int64_t pad_width = padding[1];
59   int64_t stride_height = stride[0];
60   int64_t stride_width = stride[1];
61 
62   TensorArg input_arg{input_, "input", 1};
63   TensorArg output_arg{output, "output", 2};
64   checkAllSameGPU(__func__, {input_arg, output_arg});
65 
66   im2col_shape_check(
67       input_,
68       Tensor(),
69       kernel_height,
70       kernel_width,
71       dilation_height,
72       dilation_width,
73       pad_height,
74       pad_width,
75       stride_height,
76       stride_width);
77 
78   Tensor input = input_.contiguous();
79 
80   bool batched_input = true;
81 
82   if (input.dim() == 3) {
83     batched_input = false;
84     input = input.view({1, input.size(0), input.size(1), input.size(2)});
85   }
86 
87   int64_t batch_size = input.size(0);
88   int64_t n_input_plane = input.size(1);
89   int64_t input_height = input.size(2);
90   int64_t input_width = input.size(3);
91 
92   int64_t output_height = (input_height + 2 * pad_height -
93                            (dilation_height * (kernel_height - 1) + 1)) /
94           stride_height +
95       1;
96   int64_t output_width = (input_width + 2 * pad_width -
97                           (dilation_width * (kernel_width - 1) + 1)) /
98           stride_width +
99       1;
100   int64_t n_output_plane = n_input_plane * kernel_width * kernel_height;
101   int64_t output_length = output_height * output_width;
102 
103   output.resize_({batch_size, n_output_plane, output_length});
104 
105   // Launch kernel
106   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
107       input.scalar_type(), "im2col_out_cuda", [&] {
108     Tensor input_n;
109     Tensor output_n;
110 
111     for (int64_t elt = 0; elt < batch_size; elt++) {
112       input_n = input.select(0, elt);
113       output_n = output.select(0, elt);
114 
115       im2col<scalar_t>(
116           at::cuda::getCurrentCUDAStream(),
117           input_n.const_data_ptr<scalar_t>(),
118           n_input_plane,
119           input_height,
120           input_width,
121           output_height,
122           output_width,
123           kernel_height,
124           kernel_width,
125           pad_height,
126           pad_width,
127           stride_height,
128           stride_width,
129           dilation_height,
130           dilation_width,
131           output_n.mutable_data_ptr<scalar_t>());
132     }
133 
134     if (!batched_input) {
135       output.resize_({n_output_plane, output_length});
136     }
137   });
138 }
139 
140 } // namespace
141 
im2col_out_cuda(const Tensor & input,IntArrayRef kernel_size,IntArrayRef dilation,IntArrayRef padding,IntArrayRef stride,Tensor & output)142 Tensor& im2col_out_cuda(const Tensor& input,
143     IntArrayRef kernel_size,
144     IntArrayRef dilation,
145     IntArrayRef padding,
146     IntArrayRef stride,
147     Tensor& output) {
148   im2col_out_cuda_template(
149       output, input, kernel_size, dilation, padding, stride);
150   return output;
151 }
152 
im2col_cuda(const Tensor & input,IntArrayRef kernel_size,IntArrayRef dilation,IntArrayRef padding,IntArrayRef stride)153 Tensor im2col_cuda(
154     const Tensor& input,
155     IntArrayRef kernel_size,
156     IntArrayRef dilation,
157     IntArrayRef padding,
158     IntArrayRef stride) {
159   Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
160   im2col_out_cuda_template(
161       output, input, kernel_size, dilation, padding, stride);
162   return output;
163 }
164 
165 } // namespace at::native
166