#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>
#include <ATen/native/cuda/LaunchUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/max_pool2d_with_indices_native.h>
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
#endif

namespace at::native {
namespace {

__device__ inline int min(int a, int b) {
  return a <= b ? a : b;
}

#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit

#define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched

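// p_start/p_end return the [start, end) range of pooled output indices whose
// (dilated) pooling window can cover the input element at position `size`.
// The backward kernels use them to restrict the scan over output positions.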
static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
  return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
}

static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
  return min((size + pad) / stride + 1, pooled_size);
}

// kernels borrowed from Caffe
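// NCHW forward kernel: one thread per output element. Each thread scans its
// (dilated) pooling window, records the max value and its flattened
// h * width + w index, and propagates NaN (a NaN input wins the comparison).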
template <typename scalar_t>
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
    const int64_t channels, const int64_t height,
    const int64_t width, const int pooled_height, const int pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, scalar_t* top_data,
    int64_t* top_mask) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    int hstart = ph * stride_h - pad_h;
    int wstart = pw * stride_w - pad_w;
    int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
    int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
    while(hstart < 0)
      hstart += dilation_h;
    while(wstart < 0)
      wstart += dilation_w;
    scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
    int maxidx = hstart * width + wstart;
    const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
    for (int h = hstart; h < hend; h += dilation_h) {
      for (int w = wstart; w < wend; w += dilation_w) {
        scalar_t val = btm_data[h * width + w];
        if ((val > maxval) || at::_isnan(val)) {
          maxidx = h * width + w;
          maxval = val;
        }
      }
    }
    top_data[index] = maxval;
    top_mask[index] = maxidx;
  }
}

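// Channels-last (NHWC) forward kernel. Each block handles one batch index and
// one channel slice (grid.x = nbatch * kernel_stride_C) plus a tile of output
// rows/columns (grid.z/grid.y). Per-thread running maxima and argmax indices
// are kept in shared memory rather than registers.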
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
                                   const int64_t channels, const int64_t height,
                                   const int64_t width, const int pooled_height, const int pooled_width,
                                   const int kernel_h, const int kernel_w, const int stride_h,
                                   const int stride_w, const int pad_h, const int pad_w,
                                   const int dilation_h, const int dilation_w,
                                   const int in_stride_n, const int in_stride_c,
                                   const int in_stride_h, const int in_stride_w,
                                   const int kernel_stride_C, const int kernel_size_C,
                                   scalar_t* top_data, int64_t* top_mask) {
  extern __shared__ int smem[];
  int *out_mask_cached = smem;
  scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);
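  // Shared memory layout: kernel_size_C*blockDim.x*blockDim.y*blockDim.z ints
  // for the cached argmax indices, followed by the same number of scalar_t
  // values for the cached maxima.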

  // flattening cta for pre-computation & smem initialization;
  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  int block_size = blockDim.x * blockDim.y * blockDim.z;

  // use shared memory to store temporary output value. This is simply to
  // reduce register usage.
  for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
    out_cached[i] = at::numeric_limits<scalar_t>::lower_bound();
    out_mask_cached[i] = 0;
  }

  __syncthreads();

  int batch_id = blockIdx.x % nbatch;
  int channel_id = blockIdx.x / nbatch;
  int channel_offset = threadIdx.x + channel_id * blockDim.x;

  top_data = top_data + batch_id * pooled_height * pooled_width * channels;
  top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
  bottom_data = bottom_data + batch_id * in_stride_n;

  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
  out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];

  int oH = (pooled_height + gridDim.z-1) / gridDim.z;
  int oW = (pooled_width + gridDim.y-1) / gridDim.y;
  int ostartH = threadIdx.z + blockIdx.z*oH;
  int oendH = ::min(ostartH+oH, pooled_height);
  int ostartW = threadIdx.y + blockIdx.y*oW;
  int oendW = ::min(ostartW+oW, pooled_width);

  for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
    int hstart = oh * stride_h - pad_h;
    int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
    for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
      int wstart = ow * stride_w - pad_w;
      int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
      while(hstart < 0)
        hstart += dilation_h;
      while(wstart < 0)
        wstart += dilation_w;
      for (int ih = hstart; ih < hend; ih += dilation_h) {
        for (int iw = wstart; iw < wend; iw += dilation_w) {
          int cached_index = threadIdx.x;
          const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
          for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
            scalar_t val = ptr_input[c*in_stride_c];
            if ((val > out_cached[cached_index]) || at::_isnan(val)) {
              out_cached[cached_index] = val;
              out_mask_cached[cached_index] = ih * width + iw;
            }
            cached_index += blockDim.x;
          }
        }
      }
      scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
      int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;

      int cached_index = threadIdx.x;
      for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
        ptr_output_data[c] = out_cached[cached_index];
        ptr_output_mask[c] = out_mask_cached[cached_index];
        out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
        out_mask_cached[cached_index] = 0;
        cached_index += blockDim.x;
      }
    }
  }
}


static const int BLOCK_THREADS = 256;

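// NCHW backward kernel: each thread owns one input position (h, w) and, for
// every (n, c) assigned to its block, accumulates top_diff over the output
// windows whose stored argmax equals that position.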
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
#endif
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
    const int64_t* top_mask, const int num, const int64_t channels,
    const int64_t height, const int64_t width, const int pooled_height,
    const int pooled_width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    scalar_t* bottom_diff) {
  CUDA_KERNEL_LOOP(index, height*width) {
    int h = index / width;
    int w = index - h * width;
    int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
    int phend = p_end(h, pad_h, pooled_height, stride_h);
    int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
    int pwend = p_end(w, pad_w, pooled_width, stride_w);
    for (int n = blockIdx.y; n < num; n += gridDim.y) {
      for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
        accscalar_t gradient = accscalar_t(0);
        int offset = (n * channels + c) * pooled_height * pooled_width;
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
              gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
            }
          }
        }
        bottom_diff[(n*channels+c)*height*width+index] = static_cast<scalar_t>(gradient);
      }
    }
  }
}

template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_backward_nhwc(const scalar_t* top_diff,
                                    const int64_t* top_mask, const int nbatch, const int64_t channels,
                                    const int64_t height, const int64_t width, const int pooled_height,
                                    const int pooled_width, const int kernel_h, const int kernel_w,
                                    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
                                    const int dilation_h, const int dilation_w,
                                    const int out_stride_c, const int out_stride_h, const int out_stride_w,
                                    const int kernel_stride_C, const int kernel_size_C,
                                    scalar_t* bottom_diff) {
  extern __shared__ int smem[];
  accscalar_t *out_cached = reinterpret_cast<accscalar_t*>(smem);

  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  int block_size = blockDim.x * blockDim.y * blockDim.z;

  int batch_id = blockIdx.x % nbatch;
  int channel_id = blockIdx.x / nbatch;
  int channel_offset = threadIdx.x + channel_id * blockDim.x;

  for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
    out_cached[i] = accscalar_t(0.0);
  }

  __syncthreads();

  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];

  bottom_diff = bottom_diff + batch_id * height * width * channels;
  top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
  top_diff = top_diff + batch_id * pooled_height * pooled_width * channels;

  int iH = (height + gridDim.z-1) / gridDim.z;
  int iW = (width + gridDim.y-1) / gridDim.y;
  int istartH = threadIdx.z + blockIdx.z*iH;
  int iendH = ::min(static_cast<int64_t>(istartH)+iH, height);
  int istartW = threadIdx.y + blockIdx.y*iW;
  int iendW = ::min(static_cast<int64_t>(istartW)+iW, width);

  for (int ih = istartH; ih < iendH; ih+=blockDim.z) {
    int phstart = p_start(ih, pad_h, kernel_h, dilation_h, stride_h);
    int phend = p_end(ih, pad_h, pooled_height, stride_h);
    for (int iw = istartW; iw < iendW; iw+=blockDim.y) {
      int pwstart = p_start(iw, pad_w, kernel_w, dilation_w, stride_w);
      int pwend = p_end(iw, pad_w, pooled_width, stride_w);
      int index_shift = ih * width + iw;
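      // If exactly one pooling window can cover this input element in both
      // dimensions, the gradient can be written directly; otherwise accumulate
      // contributions in the shared-memory buffer first.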
      if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
        for(int oh = phstart; oh < phend; ++oh) {
          for(int ow = pwstart; ow < pwend; ++ow) {
            int cached_index = threadIdx.x;
            const int64_t* ptr_top_mask = top_mask + oh * out_stride_h + ow * out_stride_w;
            for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
              if (ptr_top_mask[c*out_stride_c] == index_shift) {
                out_cached[cached_index] +=
                  static_cast<accscalar_t>(top_diff[oh * out_stride_h + ow * out_stride_w + c*out_stride_c]);
              }
              cached_index += blockDim.x;
            }
          }
        }
        scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels;
        int cached_index = threadIdx.x;
        for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
          ptr_bottom_diff[c] = static_cast<scalar_t>(out_cached[cached_index]);
          out_cached[cached_index] = accscalar_t(0.0);
          cached_index += blockDim.x;
        }
      } else {
        const int64_t* ptr_top_mask = top_mask + phstart * out_stride_h + pwstart * out_stride_w;
        scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels;
        int cached_index = threadIdx.x;
        for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
          if (ptr_top_mask[c*out_stride_c] == index_shift) {
            ptr_bottom_diff[c] =
              static_cast<scalar_t>(top_diff[phstart * out_stride_h + pwstart * out_stride_w + c*out_stride_c]);
          }
          cached_index += blockDim.x;
        }
      }
    }
  }
}

} // namespace

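// Structured-kernel implementation of max_pool2d_with_indices for CUDA;
// `output` and `indices` are expected to be pre-sized by the meta function.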
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_cuda)
(const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
  NoNamesGuard guard;

  TensorArg output_arg{ output, "output", 1 };
  TensorArg indices_arg{ indices, "indices", 2 };
  TensorArg input_arg{ input_, "input_", 3 };

  checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg});
  if (output.numel() == 0) {
    return;
  }

  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input_.suggest_memory_format();

  const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
  const int64_t nInputPlane = input_.size(-3);
  const int64_t inputHeight = input_.size(-2);
  const int64_t inputWidth = input_.size(-1);

  const int64_t outputHeight = output.size(-2);
  const int64_t outputWidth = output.size(-1);

  Tensor input = input_.contiguous(memory_format);

  const int64_t in_stride_n = input_.ndimension() == 4 ? input.stride(-4) : 0;
  const int64_t in_stride_c = input.stride(-3);
  const int64_t in_stride_h = input.stride(-2);
  const int64_t in_stride_w = input.stride(-1);

  const int count = safe_downcast<int, int64_t>(output.numel());

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool2d_with_indices_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;

      scalar_t *output_data = output.mutable_data_ptr<scalar_t>();
      const scalar_t *input_data = input.const_data_ptr<scalar_t>();
      int64_t *indices_data = indices.mutable_data_ptr<int64_t>();

      switch (memory_format) {
        case MemoryFormat::ChannelsLast: {
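          // Block shape heuristic: x spans channels, y spans output width,
          // z spans output height; each extent is a power of two capped so the
          // block stays within max_threads. kernel_stride_C spreads the channel
          // dimension across grid.x blocks; kernel_size_C is how many channel
          // values each thread caches in shared memory.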
          const int max_threads = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
          int block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
          int block_y = std::min<int>(
              maxThreadsDim[1], std::min<int>(lastPow2(outputWidth), max_threads / block_x));
          int block_z = std::min<int>(
              maxThreadsDim[2], std::min<int>(lastPow2(outputHeight), max_threads / block_x / block_y));
          block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
          const dim3 block(block_x, block_y, block_z);

          int kernel_stride_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
          int kernel_size_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);

          int grid_x = nbatch*kernel_stride_C;
          int grid_y = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
              ceil_div(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
          int grid_z = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
              ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
          const dim3 grid(grid_x, grid_y, grid_z);

          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);

          max_pool_forward_nhwc<scalar_t>
          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
              input_data, nbatch,
                  nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                  kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                  in_stride_n, in_stride_c,
                  in_stride_h, in_stride_w,
                  kernel_stride_C, kernel_size_C,
                  output_data, indices_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        case MemoryFormat::Contiguous: {
          const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
                                            BLOCK_THREADS);
          max_pool_forward_nchw<scalar_t>
              <<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              count, input_data,
                  nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                  kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                  output_data, indices_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}

TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_cuda)
(const Tensor& gradOutput_,
const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices_,
const Tensor& gradInput) {
  NoNamesGuard guard;

  TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
  TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 };
  TensorArg input_arg{ input_, "input_", 3 };
  TensorArg indices_arg{ indices_, "indices", 4 };

  checkAllSameGPU(__func__,
                  {gradInput_arg, gradOutput_arg, input_arg, indices_arg});
  if (gradOutput_.numel() == 0) {
    return;
  }

  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input_.suggest_memory_format();

  const Tensor input = input_.contiguous(memory_format);

  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  const int64_t in_stride_n = input.ndimension() == 4 ? input.stride(-4) : 0;
  const int64_t in_stride_c = input.stride(-3);
  const int64_t in_stride_h = input.stride(-2);
  const int64_t in_stride_w = input.stride(-1);

  const Tensor gradOutput = gradOutput_.contiguous(memory_format);

  const int64_t outputHeight = gradOutput.size(-2);
  const int64_t outputWidth = gradOutput.size(-1);

  const int64_t out_stride_c = gradOutput.stride(-3);
  const int64_t out_stride_h = gradOutput.stride(-2);
  const int64_t out_stride_w = gradOutput.stride(-1);

  const Tensor indices = indices_.contiguous(memory_format);

  gradInput.zero_();

  int64_t count = input.numel();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool2d_with_indices_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;

      const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
      scalar_t *gradInput_data = gradInput.mutable_data_ptr<scalar_t>();
      const int64_t *indices_data = indices.const_data_ptr<int64_t>();

      switch (memory_format) {
        case MemoryFormat::ChannelsLast: {
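          // Same block/grid heuristic as the forward channels-last launch, but
          // tiled over the input spatial dimensions; shared memory holds one
          // accscalar_t accumulator per cached channel value.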
          const int max_threads = std::min<int>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
          int block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
          int block_y = std::min<int>(
              maxThreadsDim[1], std::min<int>(lastPow2(inputWidth), max_threads / block_x));
          int block_z = std::min<int>(
              maxThreadsDim[2], std::min<int>(lastPow2(inputHeight), max_threads / block_x / block_y));
          block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
          const dim3 block(block_x, block_y, block_z);

          int kernel_stride_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
          int kernel_size_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);

          int grid_x = nbatch*kernel_stride_C;
          int grid_y = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
              ceil_div(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
          int grid_z = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
              ceil_div(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE));
          const dim3 grid(grid_x, grid_y, grid_z);

          size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t);
          AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);

          // The backward kernel is parallelized over the input rather than the
          // output; launching it over the output would need atomic_add, which
          // would not provide much benefit on FP16.
          // Please check comments at https://github.com/pytorch/pytorch/pull/34519.
          max_pool_backward_nhwc<scalar_t, accscalar_t>
          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
                  gradOutput_data,
                  indices_data,
                  nbatch,
                  nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                  kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                  out_stride_c, out_stride_h, out_stride_w,
                  kernel_stride_C, kernel_size_C,
                  gradInput_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        case MemoryFormat::Contiguous: {
          int imgcount = inputWidth * inputHeight;
          dim3 grid;
          const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
          grid.x = blocks;
          grid.y = nbatch;
          uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
          if (maxGridY < grid.y) grid.y = maxGridY;
          grid.z = nInputPlane;
          uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
          if (maxGridZ < grid.z) grid.z = maxGridZ;

          max_pool_backward_nchw<scalar_t, accscalar_t>
          <<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
                  gradOutput_data,
                  indices_data,
                  nbatch,
                  nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                  kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                  gradInput_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}

} // at::native