#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/div_rtn.h>

#include <ATen/cuda/CUDAContext.h>

#include <ATen/native/cuda/im2col.cuh>
#include <ATen/native/im2col_shape_check.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/col2im_native.h>
#include <ATen/ops/im2col_native.h>
#endif

namespace at::native {
namespace {

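// im2col: rearranges sliding local blocks of a (N, C, H, W) input into
// columns. Every kernel-sized window (taking dilation, padding, and stride
// into account) becomes one column, so the output has shape
// (N, C * kernel_height * kernel_width, output_height * output_width).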
static void im2col_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "Expected kernel_size to have 2 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "Expected dilation to have 2 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "Expected padding to have 2 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "Expected stride to have 2 elements, but got size ",
      stride.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];

  TensorArg input_arg{input_, "input", 1};
  TensorArg output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  im2col_shape_check(
      input_,
      Tensor(),
      kernel_height,
      kernel_width,
      dilation_height,
      dilation_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width);

  Tensor input = input_.contiguous();

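  // A 3-D (unbatched) input is temporarily viewed as a batch of one so the
  // kernel loop below can treat batched and unbatched inputs uniformly; the
  // extra dimension is dropped again after the kernels run.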
  bool batched_input = true;

  if (input.dim() == 3) {
    batched_input = false;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
  }

  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);

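  // Standard convolution output-size formula:
  //   out = (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1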
  int64_t output_height = (input_height + 2 * pad_height -
      (dilation_height * (kernel_height - 1) + 1)) /
      stride_height + 1;
  int64_t output_width = (input_width + 2 * pad_width -
      (dilation_width * (kernel_width - 1) + 1)) /
      stride_width + 1;
  int64_t n_output_plane = n_input_plane * kernel_width * kernel_height;
  int64_t output_length = output_height * output_width;

  output.resize_({batch_size, n_output_plane, output_length});

  // Launch kernel
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
      input.scalar_type(), "im2col_out_cuda", [&] {
        Tensor input_n;
        Tensor output_n;

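        // One kernel launch per batch element: each launch unfolds one
        // (C, H, W) input slice into the corresponding (C*kH*kW, L) output
        // slice on the current CUDA stream.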
        for (int64_t elt = 0; elt < batch_size; elt++) {
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

          im2col<scalar_t>(
              at::cuda::getCurrentCUDAStream(),
              input_n.const_data_ptr<scalar_t>(),
              n_input_plane,
              input_height,
              input_width,
              output_height,
              output_width,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());
        }

        if (!batched_input) {
          output.resize_({n_output_plane, output_length});
        }
      });
}

} // namespace

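// Public entry points registered for the CUDA backend: an out variant that
// writes into a caller-provided tensor, and a functional variant that
// allocates its result. Illustrative call (not from this file), e.g. a 3x3
// window with unit dilation/stride and padding of 1:
//   at::im2col(input, /*kernel_size=*/{3, 3}, /*dilation=*/{1, 1},
//              /*padding=*/{1, 1}, /*stride=*/{1, 1});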
Tensor& im2col_out_cuda(const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    Tensor& output) {
  im2col_out_cuda_template(
      output, input, kernel_size, dilation, padding, stride);
  return output;
}

Tensor im2col_cuda(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  im2col_out_cuda_template(
      output, input, kernel_size, dilation, padding, stride);
  return output;
}

} // namespace at::native