#pragma once

#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>

#include <c10/macros/Macros.h>

namespace at {
namespace native {

using namespace at::cuda::detail;

// Kernel for fast unfold+copy
// (borrowed from Caffe:
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
// CUDA_NUM_THREADS = 1024
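//
// The column buffer produced here has shape
// [channels * kernel_height * kernel_width, height_col * width_col], where
// height_col and width_col are expected to be precomputed by the caller via
// the standard convolution output-size formula, e.g.
//   height_col = (height + 2 * pad_height
//                 - (dilation_height * (kernel_height - 1) + 1)) / stride_height + 1
// (and analogously for width_col).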

template <typename dt>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void im2col_kernel(
    const int64_t n,
    const dt* data_im,
    const int64_t height,
    const int64_t width,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_col) {
  CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) {
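    // Each thread handles one output location for one input channel:
    // decompose the flat index as
    //   index == (channel_in * height_col + h_out) * width_col + w_out.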
    int64_t w_out = index % width_col;

    int64_t idx = index / width_col;

    int64_t h_out = idx % height_col;
    int64_t channel_in = idx / height_col;
    int64_t channel_out = channel_in * kernel_height * kernel_width;
    int64_t h_in = h_out * stride_height - pad_height;
    int64_t w_in = w_out * stride_width - pad_width;

    dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out;
    const dt* im = data_im + (channel_in * height + h_in) * width + w_in;

    for (int64_t i = 0; i < kernel_height; ++i) {
      for (int64_t j = 0; j < kernel_width; ++j) {
        int64_t h = h_in + i * dilation_height;
        int64_t w = w_in + j * dilation_width;
        *col = (h >= 0 && w >= 0 && h < height && w < width)
            ? im[i * dilation_height * width + j * dilation_width]
            : static_cast<dt>(0);
        col += height_col * width_col;
      }
    }
  }
}

template <typename dt>
void im2col(
    cudaStream_t stream,
    const dt* data_im,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_col) {
  // We launch channels * height_col * width_col threads; each thread copies
  // one kernel window's worth of values for a single output location.
  int64_t num_kernels = channels * height_col * width_col;
  // Launch with CUDA_NUM_THREADS = 1024 threads per block.
  im2col_kernel<<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>(
      num_kernels,
      data_im,
      height,
      width,
      kernel_height,
      kernel_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width,
      dilation_height,
      dilation_width,
      height_col,
      width_col,
      data_col);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
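
// Example (illustrative sketch, not part of the public API): unfolding a
// single-channel 4x4 float image with a 3x3 kernel, unit stride/dilation and
// no padding, given hypothetical device buffers d_im (1*4*4 floats) and
// d_col ((1*3*3) * (2*2) floats):
//
//   int64_t height_col = (4 + 2 * 0 - (1 * (3 - 1) + 1)) / 1 + 1; // = 2
//   int64_t width_col = (4 + 2 * 0 - (1 * (3 - 1) + 1)) / 1 + 1; // = 2
//   im2col<float>(
//       at::cuda::getCurrentCUDAStream(), d_im,
//       /*channels=*/1, /*height=*/4, /*width=*/4, height_col, width_col,
//       /*kernel_height=*/3, /*kernel_width=*/3, /*pad_height=*/0,
//       /*pad_width=*/0, /*stride_height=*/1, /*stride_width=*/1,
//       /*dilation_height=*/1, /*dilation_width=*/1, d_col);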
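
// col2im_device accumulates, for one image pixel (c_im, h_im, w_im), the sum
// of all column-buffer entries whose receptive field covers that pixel. This
// gather formulation is what lets col2im avoid atomic adds.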
template <typename accT, typename dt>
__forceinline__ __device__ void col2im_device(
    const int64_t index,
    const dt* data_col,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im) {
  accT val = static_cast<accT>(0);
  const int64_t w_im = index % width + pad_width;
  const int64_t h_im = (index / width) % height + pad_height;
  const int64_t c_im = index / (width * height);
  int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1;
  int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1;
  // compute the start and end of the output
  const int64_t w_col_start = (w_im < kernel_extent_w)
      ? 0
      : (w_im - kernel_extent_w) / stride_width + 1;
  const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col);
  const int64_t h_col_start = (h_im < kernel_extent_h)
      ? 0
      : (h_im - kernel_extent_h) / stride_height + 1;
  const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col);

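  // [h_col_start, h_col_end) x [w_col_start, w_col_end) enumerates exactly
  // the output locations whose (dilated) kernel window can cover
  // (h_im, w_im); the divisibility check below filters out positions that
  // fall between dilation holes.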
  // TODO: use LCM of stride and dilation to avoid unnecessary loops
  for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) {
    for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) {
      int64_t h_k = (h_im - h_col * stride_height);
      int64_t w_k = (w_im - w_col * stride_width);
      if (h_k % dilation_height == 0 && w_k % dilation_width == 0) {
        h_k /= dilation_height;
        w_k /= dilation_width;
        int64_t data_col_index =
            (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col +
             h_col) *
                width_col +
            w_col;
        val += data_col[data_col_index];
      }
    }
  }
  data_im[index] = static_cast<dt>(val);
}

template <typename dt, typename accT>
C10_LAUNCH_BOUNDS_1(512)
__global__ void col2im_kernel(
    const int64_t n,
    const dt* data_col,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im) {
  CUDA_KERNEL_LOOP(index, n) {
    col2im_device<accT>(
        index,
        data_col,
        height,
        width,
        channels,
        kernel_h,
        kernel_w,
        pad_height,
        pad_width,
        stride_height,
        stride_width,
        dilation_height,
        dilation_width,
        height_col,
        width_col,
        data_im);
  }
}

template <typename dt, typename accT>
void col2im(
    cudaStream_t stream,
    const dt* data_col,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t patch_height,
    const int64_t patch_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_im) {
  int64_t num_kernels = channels * height * width;
  // To avoid involving atomic operations, we launch one thread per bottom
  // (image) element and have each thread sum over the top (column) entries
  // that touch it.
  // Launch with 512 threads per block (matching C10_LAUNCH_BOUNDS_1(512)).
  col2im_kernel<dt, accT>
      <<<GET_BLOCKS(num_kernels, 512), 512, 0, stream>>>(
          num_kernels,
          data_col,
          height,
          width,
          channels,
          patch_height,
          patch_width,
          pad_height,
          pad_width,
          stride_height,
          stride_width,
          dilation_height,
          dilation_width,
          height_col,
          width_col,
          data_im);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
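
// Example (illustrative sketch): folding the column buffer from the im2col
// example above back into a 4x4 float image d_im. col2im accumulates each
// output pixel in accT, so reduced-precision inputs would typically pass
// accT = at::acc_type<dt, /*is_cuda*/true>:
//
//   col2im<float, float>(
//       at::cuda::getCurrentCUDAStream(), d_col,
//       /*channels=*/1, /*height=*/4, /*width=*/4, height_col, width_col,
//       /*patch_height=*/3, /*patch_width=*/3, /*pad_height=*/0,
//       /*pad_width=*/0, /*stride_height=*/1, /*stride_width=*/1,
//       /*dilation_height=*/1, /*dilation_width=*/1, d_im);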

template <typename dt>
C10_LAUNCH_BOUNDS_1(512)
__global__ void col2im_batched_kernel(
    const int64_t n,
    const dt* data_col,
    const int64_t col_batch_stride,
    const int64_t nbatch,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im,
    const int64_t im_batch_stride) {
  using accT = at::acc_type<dt, /*is_cuda*/true>;
  const auto im_numel = n * nbatch;
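
  // Grid-stride loop over all nbatch * n image elements; the flat index is
  // split into a batch index and an element index within that sample.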
  CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) {
    const auto ibatch = index / n;
    const auto slice_index = index % n;

    col2im_device<accT>(
        slice_index,
        data_col + ibatch * col_batch_stride,
        height,
        width,
        channels,
        kernel_h,
        kernel_w,
        pad_height,
        pad_width,
        stride_height,
        stride_width,
        dilation_height,
        dilation_width,
        height_col,
        width_col,
        data_im + ibatch * im_batch_stride);
  }
}
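
// col2im_batched folds a whole batch with a single kernel launch (rather than
// looping over samples on the host), addressing each sample through the given
// batch strides.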
template <typename dt>
void col2im_batched(
    cudaStream_t stream,
    const dt* data_col,
    const int64_t col_batch_stride,
    const int64_t nbatch,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t patch_height,
    const int64_t patch_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_im,
    const int64_t im_batch_stride) {
  const int64_t num_kernels = channels * height * width;
  const int64_t output_numel = nbatch * num_kernels;
  if (output_numel == 0) {
    return; // No work to do
  }

  // To avoid involving atomic operations, we launch one thread per bottom
  // (image) element and have each thread sum over the top (column) entries
  // that touch it.
  // Launch with 512 threads per block (matching C10_LAUNCH_BOUNDS_1(512)).
  col2im_batched_kernel<<<GET_BLOCKS(output_numel, 512), 512, 0, stream>>>(
      num_kernels,
      data_col,
      col_batch_stride,
      nbatch,
      height,
      width,
      channels,
      patch_height,
      patch_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width,
      dilation_height,
      dilation_width,
      height_col,
      width_col,
      data_im,
      im_batch_stride);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

} // namespace native
} // namespace at