#pragma once

#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>

#include <c10/macros/Macros.h>

namespace at {
namespace native {

using namespace at::cuda::detail;

// Kernel for fast unfold+copy
// (borrowed from Caffe:
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
// CUDA_NUM_THREADS = 1024

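// Note on the layout this kernel produces: data_im is a contiguous
// [channels, height, width] image, and data_col is written as a
// [channels * kernel_height * kernel_width, height_col, width_col] matrix,
// where each column (h_out, w_out) holds the (possibly zero-padded) patch
// read by the corresponding convolution window.
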
template <typename dt>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void im2col_kernel(
    const int64_t n,
    const dt* data_im,
    const int64_t height,
    const int64_t width,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_col) {
  CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) {
    int64_t w_out = index % width_col;

    int64_t idx = index / width_col;

    int64_t h_out = idx % height_col;
    int64_t channel_in = idx / height_col;
    int64_t channel_out = channel_in * kernel_height * kernel_width;
    int64_t h_in = h_out * stride_height - pad_height;
    int64_t w_in = w_out * stride_width - pad_width;

    dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out;
    const dt* im = data_im + (channel_in * height + h_in) * width + w_in;

    for (int64_t i = 0; i < kernel_height; ++i) {
      for (int64_t j = 0; j < kernel_width; ++j) {
        int64_t h = h_in + i * dilation_height;
        int64_t w = w_in + j * dilation_width;
        *col = (h >= 0 && w >= 0 && h < height && w < width)
            ? im[i * dilation_height * width + j * dilation_width]
            : static_cast<dt>(0);
        col += height_col * width_col;
      }
    }
  }
}

template <typename dt>
void im2col(
    cudaStream_t stream,
    const dt* data_im,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_col) {
  // We launch channels * height_col * width_col threads; each thread copies
  // the kernel_height * kernel_width elements of one (channel, h_out, w_out)
  // output column.
  int64_t num_kernels = channels * height_col * width_col;
  // Launch with CUDA_NUM_THREADS = 1024 threads per block.
  im2col_kernel<<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>(
      num_kernels,
      data_im,
      height,
      width,
      kernel_height,
      kernel_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width,
      dilation_height,
      dilation_width,
      height_col,
      width_col,
      data_col);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
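
// A minimal host-side usage sketch (hypothetical pointer names and sizes;
// actual callers in ATen dispatch on the tensor's scalar type). The output
// extents follow the usual convolution arithmetic:
//
//   const int64_t height_col =
//       (height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) /
//           stride_height + 1;
//   const int64_t width_col =
//       (width + 2 * pad_width - (dilation_width * (kernel_width - 1) + 1)) /
//           stride_width + 1;
//   im2col<float>(
//       at::cuda::getCurrentCUDAStream(),
//       input_ptr,   // [channels, height, width], contiguous
//       channels, height, width,
//       height_col, width_col,
//       kernel_height, kernel_width,
//       pad_height, pad_width,
//       stride_height, stride_width,
//       dilation_height, dilation_width,
//       output_ptr); // [channels * kernel_height * kernel_width, height_col * width_col]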

template <typename accT, typename dt>
__forceinline__ __device__ void col2im_device(
    const int64_t index,
    const dt* data_col,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im) {
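  // Each call reconstructs one image element (c_im, h_im, w_im): it gathers
  // and sums every data_col entry whose kernel window covered that pixel,
  // i.e. the transpose of the im2col copy above.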
  accT val = static_cast<accT>(0);
  const int64_t w_im = index % width + pad_width;
  const int64_t h_im = (index / width) % height + pad_height;
  const int64_t c_im = index / (width * height);
  int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1;
  int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1;
  // Compute the range of output (column) locations whose window covers this
  // pixel.
  const int64_t w_col_start = (w_im < kernel_extent_w)
      ? 0
      : (w_im - kernel_extent_w) / stride_width + 1;
  const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col);
  const int64_t h_col_start = (h_im < kernel_extent_h)
      ? 0
      : (h_im - kernel_extent_h) / stride_height + 1;
  const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col);

  // TODO: use LCM of stride and dilation to avoid unnecessary loops
  for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) {
    for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) {
      int64_t h_k = (h_im - h_col * stride_height);
      int64_t w_k = (w_im - w_col * stride_width);
      if (h_k % dilation_height == 0 && w_k % dilation_width == 0) {
        h_k /= dilation_height;
        w_k /= dilation_width;
        int64_t data_col_index =
            (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col +
              h_col) *
                width_col +
            w_col;
        val += data_col[data_col_index];
      }
    }
  }
  data_im[index] = static_cast<dt>(val);
}

template <typename dt, typename accT>
C10_LAUNCH_BOUNDS_1(512)
__global__ void col2im_kernel(
    const int64_t n,
    const dt* data_col,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im) {
  CUDA_KERNEL_LOOP(index, n) {
    col2im_device<accT>(
        index,
        data_col,
        height,
        width,
        channels,
        kernel_h,
        kernel_w,
        pad_height,
        pad_width,
        stride_height,
        stride_width,
        dilation_height,
        dilation_width,
        height_col,
        width_col,
        data_im);
  }
}

template <typename dt, typename accT>
void col2im(
    cudaStream_t stream,
    const dt* data_col,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t patch_height,
    const int64_t patch_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_im) {
  int64_t num_kernels = channels * height * width;
  // To avoid atomic operations, we launch one thread per input ("bottom")
  // element and have each thread accumulate over the output ("top")
  // locations that cover it.
  // Launch with 512 threads per block (matching C10_LAUNCH_BOUNDS_1(512)).
  col2im_kernel<dt, accT>
      <<<GET_BLOCKS(num_kernels, 512), 512, 0, stream>>>(
          num_kernels,
          data_col,
          height,
          width,
          channels,
          patch_height,
          patch_width,
          pad_height,
          pad_width,
          stride_height,
          stride_width,
          dilation_height,
          dilation_width,
          height_col,
          width_col,
          data_im);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename dt>
C10_LAUNCH_BOUNDS_1(512)
__global__ void col2im_batched_kernel(
    const int64_t n,
    const dt* data_col,
    const int64_t col_batch_stride,
    const int64_t nbatch,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im,
    const int64_t im_batch_stride) {
  using accT = at::acc_type<dt, /*is_cuda*/true>;
  const auto im_numel = n * nbatch;

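  // The flat thread index encodes (batch, element): ibatch = index / n and
  // slice_index = index % n. Each batch slice reuses col2im_device with the
  // corresponding per-batch pointer offsets.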
  CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) {
    const auto ibatch = index / n;
    const auto slice_index = index % n;

    col2im_device<accT>(
        slice_index,
        data_col + ibatch * col_batch_stride,
        height,
        width,
        channels,
        kernel_h,
        kernel_w,
        pad_height,
        pad_width,
        stride_height,
        stride_width,
        dilation_height,
        dilation_width,
        height_col,
        width_col,
        data_im + ibatch * im_batch_stride);
  }
}

template <typename dt>
void col2im_batched(
    cudaStream_t stream,
    const dt* data_col,
    const int64_t col_batch_stride,
    const int64_t nbatch,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t patch_height,
    const int64_t patch_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_im,
    const int64_t im_batch_stride) {
  const int64_t num_kernels = channels * height * width;
  const int64_t output_numel = nbatch * num_kernels;
  if (output_numel == 0) {
    return;  // No work to do
  }

  // To avoid atomic operations, we launch one thread per input ("bottom")
  // element and have each thread accumulate over the output ("top")
  // locations that cover it.
  // Launch with 512 threads per block (matching C10_LAUNCH_BOUNDS_1(512)).
  col2im_batched_kernel<<<GET_BLOCKS(output_numel, 512), 512, 0, stream>>>(
          num_kernels,
          data_col,
          col_batch_stride,
          nbatch,
          height,
          width,
          channels,
          patch_height,
          patch_width,
          pad_height,
          pad_width,
          stride_height,
          stride_width,
          dilation_height,
          dilation_width,
          height_col,
          width_col,
          data_im,
          im_batch_stride);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

} // namespace native
} // namespace at