1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/AccumulateType.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorUtils.h>
6 #include <ATen/Utils.h>
7 #include <ATen/div_rtn.h>
8
9 #include <ATen/cuda/CUDAContext.h>
10
11 #include <ATen/native/cuda/im2col.cuh>
12 #include <ATen/native/im2col_shape_check.h>
13
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #include <ATen/NativeFunctions.h>
17 #else
18 #include <ATen/ops/col2im_native.h>
19 #include <ATen/ops/empty_like.h>
20 #include <ATen/ops/im2col_native.h>
21 #endif
22
23 namespace at::native {
24 namespace {
25
// Core implementation of col2im (the inverse of im2col / torch.nn.Fold) on CUDA.
//
// Rearranges a tensor of sliding-window columns back into an image-shaped
// tensor, accumulating overlapping patch contributions.
//
// Arguments:
//   output      - resized in place to
//                 {batch_size, n_output_plane, output_height, output_width}
//                 (or 3-D {n_output_plane, output_height, output_width} when
//                 the input was unbatched).
//   input_      - 2-D {C*kH*kW, L} or 3-D {N, C*kH*kW, L} column tensor;
//                 made contiguous internally.
//   output_size - {output_height, output_width} of the reconstructed image.
//   kernel_size - {kernel_height, kernel_width}.
//   dilation    - {dilation_height, dilation_width}.
//   padding     - {pad_height, pad_width}.
//   stride      - {stride_height, stride_width}.
//
// Raises (TORCH_CHECK) when any of the five geometry arguments does not have
// exactly 2 elements, when input/output live on different GPUs, or when
// col2im_shape_check rejects the geometry.
void col2im_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  // Both tensors must reside on the same CUDA device.
  TensorArg input_arg{input_, "input", 1};
  TensorArg output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  // Each geometry argument must be a (height, width) pair.
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());

  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];
  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];

  // Validates input dimensionality and that the column count matches the
  // geometry; the empty Tensor() stands in for an unused grad_output slot.
  col2im_shape_check(
      input_,
      Tensor(),
      output_height,
      output_width,
      kernel_height,
      kernel_width,
      dilation_height,
      dilation_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width);

  // The kernel assumes dense row-major layout.
  Tensor input = input_.contiguous();

  bool batched_input = true;
  if (input.dim() == 2) {
    // Force batch: view 2-D input as a batch of one so a single kernel path
    // handles both cases; undone after the launch below.
    batched_input = false;
    input = input.view({1, input.size(0), input.size(1)});
  }

  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
  // Channels of the reconstructed image: input rows are C * kH * kW.
  int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height);
  int64_t input_batch_stride = input.stride(0);

  output.resize_({batch_size, n_output_plane, output_height, output_width});
  // Stride captured after the resize so it reflects the new shape.
  int64_t output_batch_stride = output.stride(0);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
      input.scalar_type(), "col2im_out_cuda", [&] {
        // Number of sliding-window positions per spatial dimension
        // (standard convolution output-size formula).
        int64_t height_col = (output_height + 2 * pad_height -
                              (dilation_height * (kernel_height - 1) + 1)) /
                stride_height +
            1;
        int64_t width_col = (output_width + 2 * pad_width -
                             (dilation_width * (kernel_width - 1) + 1)) /
                stride_width +
            1;

        // Launches the col2im kernel over the whole batch on the current
        // stream; accumulates overlapping patch values into output.
        col2im_batched(
            at::cuda::getCurrentCUDAStream(),
            input.const_data_ptr<scalar_t>(),
            input_batch_stride,
            batch_size,
            n_output_plane,
            output_height,
            output_width,
            height_col,
            width_col,
            kernel_height,
            kernel_width,
            pad_height,
            pad_width,
            stride_height,
            stride_width,
            dilation_height,
            dilation_width,
            output.mutable_data_ptr<scalar_t>(),
            output_batch_stride);

        if (!batched_input) {
          // Drop the synthetic batch dimension added above.
          output.resize_({n_output_plane, output_height, output_width});
        }
      });
}
142
143 } // namespace
144
// Out-variant entry point for col2im on CUDA: writes the folded result into
// the caller-supplied `output` tensor (resizing it as needed) and returns it.
Tensor& col2im_out_cuda(const Tensor& input,
                        IntArrayRef output_size,
                        IntArrayRef kernel_size,
                        IntArrayRef dilation,
                        IntArrayRef padding,
                        IntArrayRef stride,
                        Tensor& output) {
  // All validation and the kernel launch live in the shared template.
  col2im_out_cuda_template(
      output,
      input,
      output_size,
      kernel_size,
      dilation,
      padding,
      stride);
  return output;
}
156
// Allocating entry point for col2im on CUDA: creates a fresh result tensor
// with the same options as `input` and delegates to the shared template,
// which resizes it to the folded output shape.
Tensor col2im_cuda(
    const Tensor& input,
    IntArrayRef output_size,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  // Shape is irrelevant here; the template resize_()s to the final geometry.
  auto result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  col2im_out_cuda_template(
      result, input, output_size, kernel_size, dilation, padding, stride);
  return result;
}
170
171 } // namespace at::native
172