#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>
#include <ATen/native/cuda/LaunchUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/max_pool2d_with_indices_native.h>
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
#endif

namespace at::native {
namespace {

__device__ inline int min(int a, int b) {
  return a <= b ? a : b;
}

#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit

#define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched

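// Given an input coordinate along one spatial dimension, p_start/p_end return
// a conservative [start, end) range of pooled output indices whose (dilated)
// pooling window may cover that coordinate; the backward kernels still check
// the recorded argmax before letting any gradient flow.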
static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
  return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
}

static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
  return min((size + pad) / stride + 1, pooled_size);
}

// kernels borrowed from Caffe
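// NCHW forward: one thread per output element (n, c, ph, pw). Each thread
// scans its (dilated) pooling window, letting NaNs win the comparison, and
// writes the max value to top_data and the flattened argmax (h * width + w)
// to top_mask.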
template <typename scalar_t>
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
    const int64_t channels, const int64_t height,
    const int64_t width, const int pooled_height, const int pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, scalar_t* top_data,
    int64_t* top_mask) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    int hstart = ph * stride_h - pad_h;
    int wstart = pw * stride_w - pad_w;
    int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
    int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
    while(hstart < 0)
      hstart += dilation_h;
    while(wstart < 0)
      wstart += dilation_w;
    scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
    int maxidx = hstart * width + wstart;
    const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
    for (int h = hstart; h < hend; h += dilation_h) {
      for (int w = wstart; w < wend; w += dilation_w) {
        scalar_t val = btm_data[h * width + w];
        if ((val > maxval) || at::_isnan(val)) {
          maxidx = h * width + w;
          maxval = val;
        }
      }
    }
    top_data[index] = maxval;
    top_mask[index] = maxidx;
  }
}

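// NHWC (channels-last) forward: blockIdx.x enumerates (batch, channel-chunk)
// pairs while gridDim.y/z tile the output width/height. Each thread keeps a
// running max and argmax for kernel_size_C channels in shared memory
// (out_cached / out_mask_cached) and only writes them out once per output
// position, which keeps register pressure down.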
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
    const int64_t channels, const int64_t height,
    const int64_t width, const int pooled_height, const int pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    const int in_stride_n, const int in_stride_c,
    const int in_stride_h, const int in_stride_w,
    const int kernel_stride_C, const int kernel_size_C,
    scalar_t* top_data, int64_t* top_mask) {
  extern __shared__ int smem[];
  int *out_mask_cached = smem;
  scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);

  // flatten the CTA for pre-computation & shared-memory initialization
  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  int block_size = blockDim.x * blockDim.y * blockDim.z;

  // use shared memory to store temporary output values, simply to
  // reduce register usage
  for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
    out_cached[i] = at::numeric_limits<scalar_t>::lower_bound();
    out_mask_cached[i] = 0;
  }

  __syncthreads();

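  // blockIdx.x encodes both the batch index and the channel chunk handled by
  // this block; each thread then strides over channels by
  // blockDim.x * kernel_stride_C starting at its channel_offset.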
  int batch_id = blockIdx.x % nbatch;
  int channel_id = blockIdx.x / nbatch;
  int channel_offset = threadIdx.x + channel_id * blockDim.x;

  top_data = top_data + batch_id * pooled_height * pooled_width * channels;
  top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
  bottom_data = bottom_data + batch_id * in_stride_n;

  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
  out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];

  int oH = (pooled_height + gridDim.z-1) / gridDim.z;
  int oW = (pooled_width + gridDim.y-1) / gridDim.y;
  int ostartH = threadIdx.z + blockIdx.z*oH;
  int oendH = ::min(ostartH+oH, pooled_height);
  int ostartW = threadIdx.y + blockIdx.y*oW;
  int oendW = ::min(ostartW+oW, pooled_width);

  for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
    int hstart = oh * stride_h - pad_h;
    int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
    for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
      int wstart = ow * stride_w - pad_w;
      int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
      while(hstart < 0)
        hstart += dilation_h;
      while(wstart < 0)
        wstart += dilation_w;
      for (int ih = hstart; ih < hend; ih += dilation_h) {
        for (int iw = wstart; iw < wend; iw += dilation_w) {
          int cached_index = threadIdx.x;
          const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
          for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
            scalar_t val = ptr_input[c*in_stride_c];
            if ((val > out_cached[cached_index]) || at::_isnan(val)) {
              out_cached[cached_index] = val;
              out_mask_cached[cached_index] = ih * width + iw;
            }
            cached_index += blockDim.x;
          }
        }
      }
      scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
      int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;

      int cached_index = threadIdx.x;
      for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
        ptr_output_data[c] = out_cached[cached_index];
        ptr_output_mask[c] = out_mask_cached[cached_index];
        out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
        out_mask_cached[cached_index] = 0;
        cached_index += blockDim.x;
      }
    }
  }
}


static const int BLOCK_THREADS = 256;

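// NCHW backward: one thread per input position (h, w), with gridDim.y/z
// striding over batch and channels. Each thread accumulates, in accscalar_t,
// the gradients of all pooled outputs whose recorded argmax equals this
// position, then writes the sum to bottom_diff; no atomics are needed since
// every input element is owned by exactly one thread.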
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
#endif
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
    const int64_t* top_mask, const int num, const int64_t channels,
    const int64_t height, const int64_t width, const int pooled_height,
    const int pooled_width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    scalar_t* bottom_diff) {
  CUDA_KERNEL_LOOP(index, height*width) {
    int h = index / width;
    int w = index - h * width;
    int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
    int phend = p_end(h, pad_h, pooled_height, stride_h);
    int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
    int pwend = p_end(w, pad_w, pooled_width, stride_w);
    for (int n = blockIdx.y; n < num; n += gridDim.y) {
      for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
        accscalar_t gradient = accscalar_t(0);
        int offset = (n * channels + c) * pooled_height * pooled_width;
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
              gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
            }
          }
        }
        bottom_diff[(n*channels+c)*height*width+index] = static_cast<scalar_t>(gradient);
      }
    }
  }
}

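// NHWC (channels-last) backward: mirrors the forward launch; blockIdx.x
// enumerates (batch, channel-chunk) pairs while gridDim.y/z tile the input
// width/height. Per-channel gradients are accumulated in shared memory as
// accscalar_t, except when an input position can only belong to a single
// pooling window, in which case the gradient is copied through directly.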
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_backward_nhwc(const scalar_t* top_diff,
    const int64_t* top_mask, const int nbatch, const int64_t channels,
    const int64_t height, const int64_t width, const int pooled_height,
    const int pooled_width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    const int out_stride_c, const int out_stride_h, const int out_stride_w,
    const int kernel_stride_C, const int kernel_size_C,
    scalar_t* bottom_diff) {
  extern __shared__ int smem[];
  accscalar_t *out_cached = reinterpret_cast<accscalar_t*>(smem);

  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  int block_size = blockDim.x * blockDim.y * blockDim.z;

  int batch_id = blockIdx.x % nbatch;
  int channel_id = blockIdx.x / nbatch;
  int channel_offset = threadIdx.x + channel_id * blockDim.x;

  for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
    out_cached[i] = accscalar_t(0.0);
  }

  __syncthreads();

  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];

  bottom_diff = bottom_diff + batch_id * height * width * channels;
  top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
  top_diff = top_diff + batch_id * pooled_height * pooled_width * channels;

  int iH = (height + gridDim.z-1) / gridDim.z;
  int iW = (width + gridDim.y-1) / gridDim.y;
  int istartH = threadIdx.z + blockIdx.z*iH;
  int iendH = ::min(static_cast<int64_t>(istartH)+iH, height);
  int istartW = threadIdx.y + blockIdx.y*iW;
  int iendW = ::min(static_cast<int64_t>(istartW)+iW, width);

  for (int ih = istartH; ih < iendH; ih+=blockDim.z) {
    int phstart = p_start(ih, pad_h, kernel_h, dilation_h, stride_h);
    int phend = p_end(ih, pad_h, pooled_height, stride_h);
    for (int iw = istartW; iw < iendW; iw+=blockDim.y) {
      int pwstart = p_start(iw, pad_w, kernel_w, dilation_w, stride_w);
      int pwend = p_end(iw, pad_w, pooled_width, stride_w);
      int index_shift = ih * width + iw;
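      // If more than one pooling window may cover (ih, iw), accumulate the
      // contributions in shared memory; otherwise there is exactly one
      // candidate window and the gradient can be copied through directly.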
      if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
        for(int oh = phstart; oh < phend; ++oh) {
          for(int ow = pwstart; ow < pwend; ++ow) {
            int cached_index = threadIdx.x;
            const int64_t* ptr_top_mask = top_mask + oh * out_stride_h + ow * out_stride_w;
            for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
              if (ptr_top_mask[c*out_stride_c] == index_shift) {
                out_cached[cached_index] +=
                  static_cast<accscalar_t>(top_diff[oh * out_stride_h + ow * out_stride_w + c*out_stride_c]);
              }
              cached_index += blockDim.x;
            }
          }
        }
        scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels;
        int cached_index = threadIdx.x;
        for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
          ptr_bottom_diff[c] = static_cast<scalar_t>(out_cached[cached_index]);
          out_cached[cached_index] = accscalar_t(0.0);
          cached_index += blockDim.x;
        }
      } else {
        const int64_t* ptr_top_mask = top_mask + phstart * out_stride_h + pwstart * out_stride_w;
        scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels;
        int cached_index = threadIdx.x;
        for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
          if (ptr_top_mask[c*out_stride_c] == index_shift) {
            ptr_bottom_diff[c] =
              static_cast<scalar_t>(top_diff[phstart * out_stride_h + pwstart * out_stride_w + c*out_stride_c]);
          }
          cached_index += blockDim.x;
        }
      }
    }
  }
}

} // namespace

TORCH_IMPL_FUNC(max_pool2d_with_indices_out_cuda)
(const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
  NoNamesGuard guard;

  TensorArg output_arg{ output, "output", 1 };
  TensorArg indices_arg{ indices, "indices", 2 };
  TensorArg input_arg{ input_, "input_", 3 };

  checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg});
  if (output.numel() == 0) {
    return;
  }

  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input_.suggest_memory_format();

  const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
  const int64_t nInputPlane = input_.size(-3);
  const int64_t inputHeight = input_.size(-2);
  const int64_t inputWidth = input_.size(-1);

  const int64_t outputHeight = output.size(-2);
  const int64_t outputWidth = output.size(-1);

  Tensor input = input_.contiguous(memory_format);

  const int64_t in_stride_n = input_.ndimension() == 4 ? input.stride(-4) : 0;
  const int64_t in_stride_c = input.stride(-3);
  const int64_t in_stride_h = input.stride(-2);
  const int64_t in_stride_w = input.stride(-1);

  const int count = safe_downcast<int, int64_t>(output.numel());

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool2d_with_indices_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;

      scalar_t *output_data = output.mutable_data_ptr<scalar_t>();
      const scalar_t *input_data = input.const_data_ptr<scalar_t>();
      int64_t *indices_data = indices.mutable_data_ptr<int64_t>();

      switch (memory_format) {
        case MemoryFormat::ChannelsLast: {
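          // Thread block heuristic: block_x covers channels (capped at the
          // warp size), block_y covers output width, block_z covers output
          // height; block_x is then recomputed to absorb any remaining thread
          // budget, all within CUDA_MAX_THREADS and the device limits.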
          const int max_threads = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
          int block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
          int block_y = std::min<int>(
              maxThreadsDim[1], std::min<int>(lastPow2(outputWidth), max_threads / block_x));
          int block_z = std::min<int>(
              maxThreadsDim[2], std::min<int>(lastPow2(outputHeight), max_threads / block_x / block_y));
          block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
          const dim3 block(block_x, block_y, block_z);

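          // kernel_stride_C is the number of channel chunks spread across
          // gridDim.x (together with the batch dimension); kernel_size_C is
          // the number of channels each thread caches in shared memory.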
          int kernel_stride_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
          int kernel_size_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);

          int grid_x = nbatch*kernel_stride_C;
          int grid_y = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
              ceil_div(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
          int grid_z = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
              ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
          const dim3 grid(grid_x, grid_y, grid_z);

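          // Shared memory holds, per thread, kernel_size_C argmax entries
          // (int) and kernel_size_C running max values (scalar_t), matching
          // out_mask_cached / out_cached in max_pool_forward_nhwc.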
          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);

          max_pool_forward_nhwc<scalar_t>
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
              input_data, nbatch,
              nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
              in_stride_n, in_stride_c,
              in_stride_h, in_stride_w,
              kernel_stride_C, kernel_size_C,
              output_data, indices_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        case MemoryFormat::Contiguous: {
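          // Contiguous (NCHW) path: launch one thread per output element;
          // count is the total number of output elements.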
          const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
                                           BLOCK_THREADS);
          max_pool_forward_nchw<scalar_t>
              <<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              count, input_data,
              nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
              output_data, indices_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}

TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_cuda)
(const Tensor& gradOutput_,
const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices_,
const Tensor& gradInput) {
  NoNamesGuard guard;

  TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
  TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 };
  TensorArg input_arg{ input_, "input_", 3 };
  TensorArg indices_arg{ indices_, "indices", 4 };

  checkAllSameGPU(__func__,
                  {gradInput_arg, gradOutput_arg, input_arg, indices_arg});
  if (gradOutput_.numel() == 0) {
    return;
  }

  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input_.suggest_memory_format();

  const Tensor input = input_.contiguous(memory_format);

  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  const int64_t in_stride_n = input.ndimension() == 4 ? input.stride(-4) : 0;
  const int64_t in_stride_c = input.stride(-3);
  const int64_t in_stride_h = input.stride(-2);
  const int64_t in_stride_w = input.stride(-1);

  const Tensor gradOutput = gradOutput_.contiguous(memory_format);

  const int64_t outputHeight = gradOutput.size(-2);
  const int64_t outputWidth = gradOutput.size(-1);

  const int64_t out_stride_c = gradOutput.stride(-3);
  const int64_t out_stride_h = gradOutput.stride(-2);
  const int64_t out_stride_w = gradOutput.stride(-1);

  const Tensor indices = indices_.contiguous(memory_format);

  gradInput.zero_();

  int64_t count = input.numel();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool2d_with_indices_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;

      const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
      scalar_t *gradInput_data = gradInput.mutable_data_ptr<scalar_t>();
      const int64_t *indices_data = indices.const_data_ptr<int64_t>();

      switch (memory_format) {
        case MemoryFormat::ChannelsLast: {
          const int max_threads = std::min<int>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
          int block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
          int block_y = std::min<int>(
              maxThreadsDim[1], std::min<int>(lastPow2(inputWidth), max_threads / block_x));
          int block_z = std::min<int>(
              maxThreadsDim[2], std::min<int>(lastPow2(inputHeight), max_threads / block_x / block_y));
          block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
          const dim3 block(block_x, block_y, block_z);

          int kernel_stride_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
          int kernel_size_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);

          int grid_x = nbatch*kernel_stride_C;
          int grid_y = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
              ceil_div(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
          int grid_z = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
              ceil_div(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE));
          const dim3 grid(grid_x, grid_y, grid_z);

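          // Shared memory holds kernel_size_C gradient accumulators per
          // thread, in accscalar_t, matching out_cached in
          // max_pool_backward_nhwc.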
          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t);
          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);

          // The backward kernel is launched over the input instead of the output.
          // If it were launched over the output, atomic_add would be required,
          // which does not provide much benefit on FP16.
          // Please check comments at https://github.com/pytorch/pytorch/pull/34519.
          max_pool_backward_nhwc<scalar_t, accscalar_t>
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
              gradOutput_data,
              indices_data,
              nbatch,
              nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
              out_stride_c, out_stride_h, out_stride_w,
              kernel_stride_C, kernel_size_C,
              gradInput_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        case MemoryFormat::Contiguous: {
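          // Contiguous (NCHW) path: grid.x tiles the input spatial positions,
          // grid.y the batch and grid.z the channels, each clamped to the
          // device grid-size limits (the kernel strides over any remainder).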
          int imgcount = inputWidth * inputHeight;
          dim3 grid;
          const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
          grid.x = blocks;
          grid.y = nbatch;
          uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
          if (maxGridY < grid.y) grid.y = maxGridY;
          grid.z = nInputPlane;
          uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
          if (maxGridZ < grid.z) grid.z = maxGridZ;

          max_pool_backward_nchw<scalar_t, accscalar_t>
              <<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
              gradOutput_data,
              indices_data,
              nbatch,
              nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
              kH, kW, dH, dW, padH, padW, dilationH, dilationW,
              gradInput_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}

} // at::native