#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/Utils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/LaunchUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_adaptive_avg_pool2d_backward_native.h>
#include <ATen/ops/_adaptive_avg_pool2d_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <ATen/native/AdaptivePooling.h>

#include <algorithm>
#include <cfloat>
#include <cmath>

#define START_IND(a,b,c) ((int64_t)((a / b) * c + ((a % b) * c) / b))
#define END_IND(a,b,c) (1 + ((int64_t)(a + 1) * c - 1) / b)

#define START_IND_INT(a,b,c) ((a * c) / b)
#define END_IND_INT(a,b,c) (((a + 1) * c + b - 1) / b)
// #define START_IND(a,b,c) a * c / b
// #define END_IND(a,b,c)  (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0
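// START_IND/END_IND map output index a (out of b output elements) over c input
// elements to the half-open input range [floor(a * c / b), ceil((a + 1) * c / b));
// the *_INT variants compute the same values with plain integer arithmetic.
// For example, with isizeH = 5 and osizeH = 3:
//   oh = 0 -> input rows [0, 2), kH = 2
//   oh = 1 -> input rows [1, 4), kH = 3
//   oh = 2 -> input rows [3, 5), kH = 2
// Adjacent windows overlap whenever the input size is not a multiple of the
// output size.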

#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
#define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched

namespace at::native {

namespace {

  // 4d tensor B x D x H x W
  // All kernels view batch dim B and feature dim D as collapsed.

  /*
   * Description:
   *    this function adaptively average pools an input 4D tensor along dimensions 2 and 3
   *    4D input, 4D output
   */
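  // Thread mapping (as launched below with blocks = (B*D, blocksH) and
  // threads = (32, 8)): blockIdx.x selects one collapsed (B x D) plane,
  // threadIdx.y/blockIdx.y stride over output rows, and threadIdx.x strides
  // over output columns.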
   template <typename scalar_t>
  __global__ void adaptive_average_pool(const scalar_t *input, scalar_t *output,
                          int isizeH, int isizeW,
                          int osizeH, int osizeW,
                          int64_t istrideD, int64_t istrideH, int64_t istrideW)
  {
    using opmath_t = at::opmath_type<scalar_t>;
    // iterators on output pixels
    int oh, ow;

    // select input/output plane based on thread/block ID
    int o_plane = blockIdx.x;
    int i_plane = o_plane;

    output = output + o_plane*osizeH*osizeW;
    input = input + i_plane*istrideD;

    int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
    int oendH = osizeH;
    const int ostepH = blockDim.y*gridDim.y;

    int ostartW = threadIdx.x;
    int oendW = osizeW;
    const int ostepW = blockDim.x;

    // For all output pixels...
    for(oh = ostartH; oh < oendH; oh += ostepH) {

      int istartH = START_IND(oh, osizeH, isizeH);
      int iendH   = END_IND(oh, osizeH, isizeH);
      int kH = iendH - istartH;

      for(ow = ostartW; ow < oendW; ow += ostepW) {

        int istartW = START_IND(ow, osizeW, isizeW);
        int iendW   = END_IND(ow, osizeW, isizeW);
        int kW = iendW - istartW;

        // Compute the average pooling over corresponding input pixels
        const scalar_t *ptr_input = input + istartH*istrideH + istartW*istrideW;
        scalar_t *ptr_output = output + oh*osizeW + ow;
        opmath_t sum = static_cast<opmath_t>(0);
        int ih, iw;
        for(ih = 0; ih < kH; ++ih) {
          for(iw = 0; iw < kW; ++iw) {
            scalar_t val = ptr_input[iw*istrideW];
            sum += val;
          }
          ptr_input += istrideH; // next input line
        }
        // Update output
        *ptr_output = sum / kH / kW;
      }
    }
  }

  /*
   * Description:
   *    this function computes the gradInput from gradOutput
   */
   template <typename T>
  __global__ void adaptive_average_gradinput(
    T *gradInput, const T *gradOutput,
    int isizeH, int isizeW, int osizeH, int osizeW
  )
  {
    // iterators on input pixels
    int ih, iw;

    // select input/output plane based on thread/block ID
    int i_plane = blockIdx.x;
    int o_plane = i_plane;

    gradOutput = gradOutput + o_plane*osizeH*osizeW;
    gradInput = gradInput + i_plane*isizeH*isizeW;

    int istartH = blockDim.y*blockIdx.y + threadIdx.y;
    int iendH = isizeH;
    int istepH = blockDim.y*gridDim.y;

    int istartW = threadIdx.x;
    int iendW = isizeW;
    int istepW = blockDim.x;

    // compute gradInput
    for(ih = istartH; ih < iendH; ih += istepH) {

      int ostartH = START_IND(ih, isizeH, osizeH);
      int oendH   = END_IND(ih, isizeH, osizeH);

      for(iw = istartW; iw < iendW; iw += istepW) {

        int ostartW = START_IND(iw, isizeW, osizeW);
        int oendW   = END_IND(iw, isizeW, osizeW);

        // Compute the gradients over corresponding output pixels
        T *ptr_gradInput = gradInput + ih*isizeW + iw;

        int oh, ow;
        for(oh = ostartH; oh < oendH; ++oh) {
          int kH = START_IND(oh, osizeH, isizeH) - END_IND(oh, osizeH, isizeH);
          for(ow = ostartW; ow < oendW; ++ow) {
            int kW = START_IND(ow, osizeW, isizeW) - END_IND(ow, osizeW, isizeW);
            T grad_delta = gradOutput[ow + oh*osizeW] / kH / kW;
            *ptr_gradInput += grad_delta;
          }
        }
      }
    }
  }

  /*
   * Description:
   *    this function computes the gradInput from gradOutput
   *    (uses atomic add)
   */
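  // Unlike adaptive_average_gradinput above, this kernel walks the output
  // pixels and scatters each gradient into its input window. Windows of
  // neighboring output pixels can overlap (whenever isize is not a multiple of
  // osize), so different threads may update the same gradInput element, which
  // is why the accumulation has to be atomic.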
   template <typename T>
  __global__ void atomic_adaptive_average_gradinput(
    T *gradInput, const T *gradOutput,
    int isizeH, int isizeW, int osizeH, int osizeW
  )
  {
    // iterators on output indices
    int oh, ow;

    // select input/output plane based on thread/block ID
    int o_plane = blockIdx.x;
    int i_plane = o_plane;

    gradOutput = gradOutput + o_plane*osizeW*osizeH;
    gradInput = gradInput + i_plane*isizeW*isizeH;

    int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
    int oendH = osizeH;
    int ostepH = blockDim.y*gridDim.y;

    int ostartW = threadIdx.x;
    int oendW = osizeW;
    int ostepW = blockDim.x;

    // For all output pixels...
    for(oh = ostartH; oh < oendH; oh += ostepH) {

      int istartH = START_IND(oh, osizeH, isizeH);
      int iendH   = END_IND(oh, osizeH, isizeH);
      int kH = iendH - istartH;

      for(ow = ostartW; ow < oendW; ow += ostepW) {

        int istartW = START_IND(ow, osizeW, isizeW);
        int iendW   = END_IND(ow, osizeW, isizeW);
        int kW = iendW - istartW;

        // Compute the gradients over the corresponding input pixels
        T *ptr_gradInput = gradInput + istartH*isizeW + istartW;
        const T *ptr_gradOutput = gradOutput + oh*osizeW + ow;
        T grad_delta = *ptr_gradOutput / kW / kH;

        int ih, iw;
        for(ih = 0; ih < kH; ++ih) {
          for(iw = 0; iw < kW; ++iw) {
            // atomic add since different threads could update same variable
            gpuAtomicAddNoReturn(&(ptr_gradInput[iw]), grad_delta);
          }
          ptr_gradInput += isizeW; // next input line
        }
      }
    }
  }

  /*
   * Description:
   *    this function adaptively average pools an input 4D tensor along dimensions 2 and 3
   *    NHWC layout for both input and output tensor
   *    4D input, 4D output
   */
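  // Launch mapping used by the launcher below: N -> grid.x (also strided over
  // channel groups via kernel_stride_C), H -> grid.z * block.z,
  // W -> grid.y * block.y, C -> block.x. Each thread accumulates kernel_size_C
  // partial channel sums in shared memory (out_cached) before writing them out.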
   template <typename index_t, typename scalar_t>
  C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
  __global__ void adaptive_average_pool_nhwc(const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
                          int sizeB, int sizeC,
                          int isizeH, int isizeW,
                          int osizeH, int osizeW,
                          int kernel_stride_C, int kernel_size_C,
                          index_t istrideB, index_t istrideC,
                          index_t istrideH, index_t istrideW)
  {
    using opmath_t = at::opmath_type<scalar_t>;
    extern __shared__ int smem[];
    opmath_t *out_cached = reinterpret_cast<opmath_t*>(smem);

    // flattening cta for pre-computation & smem initialization;
    int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
    int block_size = blockDim.x * blockDim.y * blockDim.z;

    // use shared memory to store temporary output value. This is simply to
    // reduce register usage.
    for (index_t i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
      out_cached[i] = opmath_t(0.0);
    }

    __syncthreads();

    // each CTA handles a portion of a single slice on batch dimension;
    int batch_id = blockIdx.x % sizeB;
    int channel_id = blockIdx.x / sizeB;
    int channel_offset = threadIdx.x + channel_id * blockDim.x;

    // each CTA handles a single slice on batch dimension;
    // We use gridDim.x to handle striding on C as well.
    output = output + batch_id * osizeH * osizeW * sizeC;
    input = input + batch_id * istrideB;

    // split out_cached and assign each thread an exclusive slice of it;
    out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C * blockDim.x];

    // iterate on output H & W.
    // Each CTA handles a consecutive H & W section (TILE); Do NOT stride CTA on
    // tile so there's a better chance to hit L1 cache.
    index_t oH = (osizeH + gridDim.z-1) / gridDim.z;
    index_t oW = (osizeW + gridDim.y-1) / gridDim.y;
    index_t ostartH = threadIdx.z + blockIdx.z*oH;
    index_t oendH = ::min(ostartH+oH, osizeH);
    index_t ostartW = threadIdx.y + blockIdx.y*oW;
    index_t oendW = ::min(ostartW+oW, osizeW);

    // Stride for threads, each warp can reuse L1 as they go. So theoretically
    // better chance to survive cache eviction.
    for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
      int istartH = START_IND_INT(oh, osizeH, isizeH);
      int iendH = END_IND_INT(oh, osizeH, isizeH);
      for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
        int istartW = START_IND_INT(ow, osizeW, isizeW);
        int iendW = END_IND_INT(ow, osizeW, isizeW);
        scalar_t factor = scalar_t(1.0) / ((iendH-istartH) * (iendW-istartW));

        // loop on input: hierarchy h->w->c; using shared memory here should
        // hopefully avoid stalling on global memory reads;
        for (index_t ih = istartH; ih < iendH; ih++) {
          for (index_t iw = istartW; iw < iendW; iw++) {
            int cached_index = threadIdx.x;
            const scalar_t *ptr_input = input + ih*istrideH + iw*istrideW;
            for (index_t c = channel_offset;
                 c < sizeC;
                 c += blockDim.x*kernel_stride_C) {
              out_cached[cached_index] += ptr_input[c*istrideC];
              cached_index += blockDim.x;
            }
          }
        }
        scalar_t *ptr_output = output + (oh * osizeW + ow) * sizeC;

        int cached_index = threadIdx.x;
        // write accumulated output to global memory;
        for (index_t c = channel_offset;
             c < sizeC;
             c += blockDim.x*kernel_stride_C) {
          // The factor-based scaling below can cause small numerical
          // differences when unit-testing against the NCHW kernel; switching to
          // the commented-out line could be used to verify correctness:
          // output[c] = out_cached[c] / (iendH-istartH) / (iendW-istartW);
          ptr_output[c] = out_cached[cached_index] * factor;
          out_cached[cached_index] = opmath_t(0.0);
          cached_index += blockDim.x;
        }
        // no need to __syncthreads() since each thread only touches its own
        // out_cached entries.
      }
    }
  }

  /*
   * Description:
   *    this function computes the gradInput from gradOutput
   *    NHWC layout for both input and output tensor
   *    4D input, 4D output
   */
   template <typename index_t, typename scalar_t>
  C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
  __global__ void adaptive_average_gradinput_nhwc(scalar_t* __restrict__ gradInput, const scalar_t* __restrict__ gradOutput,
                          int sizeB, int sizeC,
                          int isizeH, int isizeW,
                          int osizeH, int osizeW,
                          int kernel_stride_C, int kernel_size_C,
                          index_t ostrideB, index_t ostrideC,
                          index_t ostrideH, index_t ostrideW)
  {
    extern __shared__ int smem[];
    index_t *ostartW_cached = smem;
    index_t *oendW_cached = &ostartW_cached[isizeW];

    // be careful with alignment, in case scalar_t is fp16, we want to assign
    // int pointers first.
    scalar_t *r_kW_cached = reinterpret_cast<scalar_t*>(&oendW_cached[isizeW]);
    scalar_t *r_kH_cached = &r_kW_cached[osizeW];
    scalar_t *out_cached = &r_kH_cached[osizeH];
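    // Shared memory layout, in order: ostartW_cached[isizeW] and
    // oendW_cached[isizeW] as index_t, followed by r_kW_cached[osizeW],
    // r_kH_cached[osizeH] and the per-thread out_cached accumulators as
    // scalar_t. Placing the 4-byte int caches first avoids the alignment
    // issues that could arise if a fp16 region came first.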

    // flattening cta for pre-computation & smem initialization;
    int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
    int block_size = blockDim.x * blockDim.y * blockDim.z;

    // Precompute output start/end index per input index on width dimension;
    // Not doing this for height dimension, as that's our outermost loop.
    for (index_t i = thread_id; i < isizeW; i+= block_size) {
      ostartW_cached[i] = START_IND_INT(i, isizeW, osizeW);
      oendW_cached[i] = END_IND_INT(i, isizeW, osizeW);
    }

    // Precompute the reciprocal pooling height/width factor for each output
    // element; this is used to weight output gradients when accumulating them
    // into the input gradient.
    // Technically we don't have to compute it for the whole `osizeH`, since
    // each cta only covers a consecutive portion of the entire output. But it's
    // not going to save us from code divergence, and the shared memory saving
    // is not an issue either, so just leave it as is for now.
    for (index_t i = thread_id; i < osizeH; i+= block_size) {
      r_kH_cached[i] = scalar_t(1.0) / (END_IND_INT(i, osizeH, isizeH) - START_IND_INT(i, osizeH, isizeH));
    }
    for (index_t i = thread_id; i < osizeW; i+= block_size) {
      r_kW_cached[i] = scalar_t(1.0) / (END_IND_INT(i, osizeW, isizeW) - START_IND_INT(i, osizeW, isizeW));
    }

    // each CTA handles a portion of a single slice on batch dimension;
    int batch_id = blockIdx.x % sizeB;
    int channel_id = blockIdx.x / sizeB;
    int channel_offset = threadIdx.x + channel_id * blockDim.x;

    // use shared memory to store temporary output value. This is simply to
    // reduce register usage.
    for (index_t i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
      out_cached[i] = scalar_t(0.0);
    }

    __syncthreads();

    // each CTA handles a portion of a single slice on batch dimension;
    // We use gridDim.x to handle striding on C as well.
    gradInput = gradInput + batch_id * isizeH * isizeW * sizeC;
    gradOutput = gradOutput + batch_id * ostrideB;

    // split out_cached and assign each thread an exclusive slice of it;
    out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x * kernel_size_C];

    // iterate on input H & W.
    // Each CTA handles a consecutive H & W section (TILE); Do NOT stride CTA on
    // tile so there's a better chance to hit L1 cache.
    index_t iH = (isizeH + gridDim.z-1) / gridDim.z;
    index_t iW = (isizeW + gridDim.y-1) / gridDim.y;
    index_t istartH = threadIdx.z + blockIdx.z*iH;
    index_t iendH = ::min(istartH+iH, isizeH);
    index_t istartW = threadIdx.y + blockIdx.y*iW;
    index_t iendW = ::min(istartW+iW, isizeW);

    // Stride for threads, each warp can reuse L1 as they go. So theoretically
    // better chance to survive cache eviction.
    for (index_t ih = istartH; ih < iendH; ih+=blockDim.z) {
      index_t ostartH = START_IND_INT(ih, isizeH, osizeH);
      index_t oendH = END_IND_INT(ih, isizeH, osizeH);
      for (index_t iw = istartW; iw < iendW; iw+=blockDim.y) {
        // loop on output: hierarchy h->w->c, so we can reuse the weight factor f
        // because it remains the same for a given oh & ow
        for(index_t oh = ostartH; oh < oendH; ++oh) {
          for(index_t ow = ostartW_cached[iw]; ow < oendW_cached[iw]; ++ow) {
            scalar_t f = r_kW_cached[ow] * r_kH_cached[oh];
            const scalar_t* ptr_gradOutput = gradOutput + oh*ostrideH + ow*ostrideW;
            int cached_index = threadIdx.x;
            for (index_t c = channel_offset;
                 c < sizeC;
                 c += blockDim.x*kernel_stride_C) {
              out_cached[cached_index] += ptr_gradOutput[c*ostrideC] * f;
              cached_index += blockDim.x;
            }
          }
        }
        scalar_t *ptr_gradInput = gradInput + (ih * isizeW + iw) * sizeC;
        int cached_index = threadIdx.x;
        // write accumulated gradInput to global memory;
        for (index_t c = channel_offset;
             c < sizeC;
             c += blockDim.x*kernel_stride_C) {
          ptr_gradInput[c] = out_cached[cached_index];
          out_cached[cached_index] = scalar_t(0.0);
          cached_index += blockDim.x;
        }
        // no need to __syncthreads() since each thread only touches its own
        // out_cached entries.
      }
    }
  }

  // 4d tensor B x D x H x W

  void adaptive_avg_pool2d_out_cuda_template(
    Tensor& output,
    const Tensor& input,
    IntArrayRef output_size)
  {
    TensorArg input_arg{ input, "input", 1 },
              output_arg{ output, "output", 2 };
    checkAllSameGPU(__func__, {input_arg, output_arg});

    TORCH_CHECK(output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2");
    int64_t ndim = input.dim();
    TORCH_CHECK((ndim == 3 || ndim == 4),
      "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ", input.sizes());
    for (const auto i : {-2, -1}) {
      TORCH_CHECK(input.size(i) > 0,
        "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
        "but input has sizes ", input.sizes(), " with dimension ", i + ndim, " being "
        "empty");
    }

    Tensor input_ = input;
    switch (input.suggest_memory_format()) {
      case at::MemoryFormat::ChannelsLast: {
        // special case for tensor memory format in channels_last
        TORCH_CHECK(input.ndimension() == 4,
                    "adaptive_avg_pool2d(): Expected 4D tensor, but got ",
                    input.sizes());

        int sizeB = input_.size(0);
        int sizeC = input_.size(1);
        int isizeH = input_.size(2);
        int isizeW = input_.size(3);

        int64_t istrideB = input_.stride(0);
        int64_t istrideC = input_.stride(1);
        int64_t istrideH = input_.stride(2);
        int64_t istrideW = input_.stride(3);

        int osizeH = output_size[0];
        int osizeW = output_size[1];
        // preserve channels_last stride on output tensor;
        if (!output.is_contiguous(at::MemoryFormat::ChannelsLast)) {
          // TODO: modify this after resize_ added `memory_format` tag
          output.resize_({sizeB, sizeC, osizeH, osizeW}).as_strided_({sizeB, sizeC, osizeH, osizeW}, {sizeC*osizeH*osizeW, 1, osizeW*sizeC, sizeC});
        }

        if (output.numel() == 0) {
          return;
        }

        const int max_threads = std::min<int>(
            at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
        int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
        int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;
        size_t sharedMemPerBlock = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;

        // Launch kernel on output tensor elements. Logic behind launch config:
        // output tensor size NCHW, strides NHWC;
        // Launch on:
        // N -> grid.x
        // H -> grid.z * block.z
        // W -> grid.y * block.y
        // C -> block.x
        // encourage larger block_y & block_z for better cache hits while
        // maintaining a reasonable block_x for coalesced memory access;
        int block_x = std::min<int>(
            maxThreadsDim[0], std::min<int>(lastPow2(sizeC), at::cuda::warp_size()));
        int block_y = std::min<int>(
            maxThreadsDim[1], std::min<int>(lastPow2(osizeW), max_threads / block_x));
        int block_z = std::min<int>(
            maxThreadsDim[2], std::min<int>(lastPow2(osizeH), max_threads / block_x / block_y));
        block_x = std::min<int>(
            maxThreadsDim[0], std::min<int>(lastPow2(sizeC), max_threads / block_y / block_z));
        const dim3 block(block_x, block_y, block_z);
        int kernel_stride_C = ceil_div(sizeC, block_x * 4);
        int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C);
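        // e.g. sizeC = 128 with block_x = 32 gives kernel_stride_C =
        // ceil_div(128, 128) = 1 and kernel_size_C = ceil_div(128, 32) = 4:
        // each thread covers 4 channels in shared memory and grid_x below
        // stays at sizeB * 1.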

        // Do NOT clip grid_x, striding on Batch dimension is not in the kernel,
        // although it could be easily implemented given the current kernel.
        int grid_x = sizeB*kernel_stride_C;
        // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel;
        int grid_y = std::min<int>(
            maxGridSize[1], ceil_div(osizeW, block_y*BLOCK_STRIDE));
        int grid_z = std::min<int>(
            maxGridSize[2], ceil_div(osizeH, block_z*BLOCK_STRIDE));
        const dim3 grid(grid_x, grid_y, grid_z);


        // we are dealing with a packed tensor here. The max index is the same as numel.
        // TODO: to really support input tensors large enough to go beyond int32,
        // we will need to restrict our shared memory usage and adjust the launch
        // config;
        AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
        AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
            input_.scalar_type(), "adaptive_avg_pool2d_nhwc_cuda", [&] {
              using opmath_t = at::opmath_type<scalar_t>;
              size_t shmem_size = (kernel_size_C * block_x * block_y * block_z) * sizeof(opmath_t);
              AT_ASSERT(shmem_size <= sharedMemPerBlock);
              adaptive_average_pool_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
                input_.const_data_ptr<scalar_t>(),
                output.mutable_data_ptr<scalar_t>(),
                sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
                kernel_stride_C, kernel_size_C,
                istrideB, istrideC, istrideH, istrideW);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            }
          );
        break;
      }
      case at::MemoryFormat::Contiguous: {
        int64_t grid_x = input.size(-3);
        if (input.ndimension() == 4) {
           input_ = input.contiguous();
           grid_x *= input_.size(-4);
        }
        int64_t sizeD  = input_.size(-3);
        int64_t isizeH = input_.size(-2);
        int64_t isizeW = input_.size(-1);

        int64_t istrideD = input_.stride(-3);
        int64_t istrideH = input_.stride(-2);
        int64_t istrideW = input_.stride(-1);

        int64_t osizeH = output_size[0];
        int64_t osizeW = output_size[1];
        if (input.ndimension() == 4) {
           output.resize_({input_.size(-4), sizeD, osizeH, osizeW});
        } else {
           output.resize_({sizeD, osizeH, osizeW});
        }
        if (output.numel() == 0) {
          return;
        }

        AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
            input_.scalar_type(), "adaptive_avg_pool2d_cuda", [&] {
              const scalar_t *input_data = input_.const_data_ptr<scalar_t>();
              scalar_t *output_data = output.mutable_data_ptr<scalar_t>();

              // cuda blocks & threads:
              int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
              dim3 blocks(grid_x, blocksH);
              dim3 threads(32, 8);

              // run averagepool kernel
              adaptive_average_pool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
                input_data, output_data,
                isizeH, isizeW, osizeH, osizeW,
                istrideD, istrideH, istrideW);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            }
          );
        break;
      }
      default:
        TORCH_CHECK(
          false,
          "Unsupported memory format. Supports only ChannelsLast, Contiguous");
    }
  }

  void adaptive_avg_pool2d_backward_out_cuda_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input)
  {
    TensorArg grad_input_arg{ gradInput, "gradInput", 1 },
              grad_output_arg{ gradOutput_, "gradOutput_", 2 },
              input_arg{ input, "input", 3 };

    adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool2d_backward");

    checkAllSameGPU(__func__, {grad_input_arg, grad_output_arg, input_arg});

    switch (input.suggest_memory_format()) {
      case at::MemoryFormat::ChannelsLast: {
        // special case for tensor memory format in channels_last
        TORCH_CHECK(input.ndimension() == 4,
                    "adaptive_avg_pool2d_backward_cuda(): Expected 4D tensor, but got ", input.ndimension());

        int sizeB = input.size(0);
        int sizeC = input.size(1);
        int isizeH = input.size(2);
        int isizeW = input.size(3);

        Tensor gradOutput = gradOutput_;

        int64_t ostrideB = gradOutput.stride(0);
        int64_t ostrideC = gradOutput.stride(1);
        int64_t ostrideH = gradOutput.stride(2);
        int64_t ostrideW = gradOutput.stride(3);

        int osizeH = gradOutput.size(-2);
        int osizeW = gradOutput.size(-1);

        // preserve channels_last stride on input tensor;
        if (!gradInput.is_contiguous(at::MemoryFormat::ChannelsLast)) {
          gradInput.as_strided_(
              {sizeB, sizeC, isizeH, isizeW},
              {sizeC*isizeH*isizeW, 1, isizeW*sizeC, sizeC});
        }

        int max_threads = std::min<int>(
            at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
        int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
        int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;
        size_t sharedMemPerBlock = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;

        // Launch kernel on input tensor elements. Logic behind launch config:
        // input tensor size NCHW, strides NHWC;
        // Launch on:
        // N(C) -> grid.x (striding on C to reduce sh_mem usage)
        // H    -> grid.z * block.z
        // W    -> grid.y * block.y
        // C    -> block.x
        // encourage larger block_y & block_z for better cache hits while
        // maintaining a reasonable block_x for coalesced memory access;
        bool done = false;
        do {
          int block_x = std::max<int>(std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(sizeC), at::cuda::warp_size())), 1);
          int block_y = std::max<int>(std::min<int>(
              maxThreadsDim[1], std::min<int>(lastPow2(isizeW), max_threads / block_x)), 1);
          int block_z = std::max<int>(std::min<int>(
              maxThreadsDim[2], std::min<int>(lastPow2(isizeH), max_threads / block_x / block_y)), 1);
          block_x = std::max<int>(std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(sizeC), max_threads / block_y / block_z)), 1);
          const dim3 block(block_x, block_y, block_z);
          int kernel_stride_C = ceil_div(sizeC, block_x * 4);
          int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C);

          // Do NOT clip grid_x, striding on Batch dimension is not in the kernel,
          // although it could be easily implemented given the current kernel.
          int grid_x = sizeB*kernel_stride_C;
          // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel;
          int grid_y = std::min<int>(
              maxGridSize[1], ceil_div(isizeW, block_y*BLOCK_STRIDE));
          int grid_z = std::min<int>(
              maxGridSize[2], ceil_div(isizeH, block_z*BLOCK_STRIDE));
          const dim3 grid(grid_x, grid_y, grid_z);

          // we are dealing with a packed tensor here. The max index is the same as numel.
          // TODO: to really support input tensors large enough to go beyond int32,
          // we will need to restrict our shared memory usage and adjust the launch
          // config;
          AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
          AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
              input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
                size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
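                // shmem_size covers the per-thread out_cached accumulators
                // (kernel_size_C * block volume), the r_kH/r_kW reciprocal
                // caches (osizeH + osizeW scalars), and the two int32
                // ostartW/oendW caches of width isizeW each.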
                if (shmem_size <= sharedMemPerBlock) {
                  adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
                    gradInput.mutable_data_ptr<scalar_t>(),
                    gradOutput.const_data_ptr<scalar_t>(),
                    sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
                    kernel_stride_C, kernel_size_C,
                    ostrideB, ostrideC, ostrideH, ostrideW);
                  C10_CUDA_KERNEL_LAUNCH_CHECK();
                  done = true;
                } else {
                  TORCH_WARN_ONCE("Requested shmem_size exceeds sharedMemPerBlock limit! Reducing max_threads...");
                  max_threads /= 2;
                }
              }
            );
        } while (!done && max_threads);
        if (!done) {
          TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accommodate sharedMemPerBlock limit");
        }
        break;
      }
      case at::MemoryFormat::Contiguous: {
        bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests

        Tensor gradOutput = gradOutput_.contiguous();

        int64_t sizeD  = input.size(-3);
        int64_t isizeH = input.size(-2);
        int64_t isizeW = input.size(-1);

        int64_t osizeH = gradOutput.size(-2);
        int64_t osizeW = gradOutput.size(-1);

        int64_t grid_x = sizeD;
        if (input.ndimension() == 4) grid_x *= input.size(-4);

          //bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0);
        AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
            input.scalar_type(), "adaptive_avg_pool2d_backward_cuda", [&] {
              const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
              scalar_t *gradInput_data = gradInput.mutable_data_ptr<scalar_t>();

              // cuda blocks & threads:
              int blocksH = std::max((int)(16L / sizeD), 1);
              dim3 blocks(grid_x, blocksH);
              dim3 threads(32, 8);

              if(atomic)
              {
                // run updateGradInput kernel, accumulate gradients atomically
                atomic_adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
                  gradInput_data, gradOutput_data,
                  isizeH, isizeW, osizeH, osizeW);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              }
              else
              {
                // run updateGradInput kernel
                adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
                  gradInput_data, gradOutput_data,
                  isizeH, isizeW, osizeH, osizeW);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              }
            }
          );
        break;
      }
      default:
        TORCH_CHECK(
          false,
          "Unsupported memory format. Supports only ChannelsLast, Contiguous");

    }
  }

} // namespace

  Tensor& adaptive_avg_pool2d_out_cuda(
    const Tensor& input,
    IntArrayRef output_size,
    Tensor& output)
  {
    adaptive_avg_pool2d_out_cuda_template(
      output, input, output_size);
    return output;
  }

  Tensor adaptive_avg_pool2d_cuda(
    at::Tensor const& input,
    IntArrayRef output_size)
  {
    auto output = at::empty({0}, input.options());
    adaptive_avg_pool2d_out_cuda_template(
      output, input, output_size);
    return output;
  }
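  // Usage note: judging by the _adaptive_avg_pool2d op headers included above,
  // these functions are expected to back the _adaptive_avg_pool2d CUDA dispatch
  // (reached e.g. via at::adaptive_avg_pool2d on a CUDA tensor), with the
  // backward entry points below registered analogously for
  // _adaptive_avg_pool2d_backward.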

  Tensor& adaptive_avg_pool2d_backward_out_cuda(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input)
  {
    // See Note [Writing Nondeterministic Operations]
    // Nondeterministic because of atomicAdd usage
    globalContext().alertNotDeterministic("adaptive_avg_pool2d_backward_out_cuda");
    gradInput.resize_as_(input);
    if (gradInput.numel() != 0) {
      adaptive_avg_pool2d_backward_out_cuda_template(
        gradInput, gradOutput, input);
    }
    return gradInput;
  }

  Tensor adaptive_avg_pool2d_backward_cuda(
    const Tensor& gradOutput,
    const Tensor& input)
  {
    // See Note [Writing Nondeterministic Operations]
    // Nondeterministic because of atomicAdd usage
    globalContext().alertNotDeterministic("adaptive_avg_pool2d_backward_cuda");
    auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    if (gradInput.numel() != 0) {
      adaptive_avg_pool2d_backward_out_cuda_template(
        gradInput, gradOutput, input);
    }
    return gradInput;
  }

} // namespace at::native

#undef BLOCK_STRIDE
#undef CUDA_MAX_THREADS
#undef START_IND
#undef END_IND