#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/adaptive_max_pool2d_backward_native.h>
#include <ATen/ops/adaptive_max_pool2d_native.h>
#include <ATen/ops/empty.h>
#endif

#include <algorithm>
#include <cfloat>
#include <cmath>


namespace at::native {

namespace {

__device__ inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
  return (a / b) * c + ((a % b) * c) / b;
}

__device__ inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
  return 1 + ((a + 1) * c - 1) / b;
}
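
// Worked example of the window computation above (illustrative numbers, not from
// the original source): for an output size b = 3 over an input size c = 10, output
// indices a = 0, 1, 2 map to the input windows [0, 4), [3, 7), [6, 10). Note that
// adjacent windows may overlap, which matters for the backward kernels below.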

// 4d tensor B x D x H x W

/*
 * Description:
 *    this function adaptively maxpools an input 4D tensor along dimensions 2 and 3
 *    4D input, 4D output, 4D argmax x and y
 */
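// Thread mapping (as set up by the host launch code below): blockIdx.x selects one
// (batch * channel) plane; threadIdx.x strides over the output width and
// blockDim.y * blockIdx.y + threadIdx.y strides over the output height. Each thread
// writes one output value and its argmax, stored as the flattened offset
// h * isizeW + w within the input plane.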
template <typename T>
__global__ void adaptivemaxpool(const T *input, T *output, int64_t *indices,
                        int isizeH, int isizeW,
                        int osizeH, int osizeW,
                        int64_t istrideD, int64_t istrideH, int64_t istrideW)
{
  // iterators
  int oh, ow;

  // compute offsets based on thread/block ID
  int o_plane = blockIdx.x;
  int i_plane = o_plane;

  int ostartW = threadIdx.x;
  int oendW = osizeW;
  const int ostepW = blockDim.x;

  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  int oendH = osizeH;
  const int ostepH = blockDim.y*gridDim.y;
  // select input/output plane
  output = output + o_plane*osizeH*osizeW;
  input = input + i_plane*istrideD;
  indices = indices + o_plane*osizeH*osizeW;

  // For all output pixels...
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    int istartH = start_index(oh, osizeH, isizeH);
    int iendH   = end_index(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for(ow = ostartW; ow < oendW; ow += ostepW) {
      int istartW = start_index(ow, osizeW, isizeW);
      int iendW   = end_index(ow, osizeW, isizeW);

      int kW = iendW - istartW;

      // Find the max (and its argmax) over the pooling window...
      const T *ptr_input = input + istartH*istrideH + istartW*istrideW;
      T *ptr_output = output + oh*osizeW + ow;
      int64_t *ptr_ind = indices + oh*osizeW + ow;
      int argmax = istartH * isizeW + istartW;
      T max = at::numeric_limits<T>::lower_bound(); // -Infinity
      int ih, iw;
      for(ih = 0; ih < kH; ih++) {
        for(iw = 0; iw < kW; iw++) {
          T val = ptr_input[iw*istrideW];
          if ((val > max) || at::_isnan(val)) {
            max = val;
            argmax = (ih+istartH)*isizeW + iw+istartW;
          }
        }
        ptr_input += istrideH; // next input line
      }
      // Update output and argmax
      *ptr_output = max;
      *ptr_ind = argmax;
    }
  }
}

/*
 * Description:
 *    this function computes the gradInput from the gradOutput and the argmax indices
 */
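// Note: this non-atomic version is only correct when no two output locations share
// the same argmax (i.e. the pooling windows do not overlap); otherwise the plain
// += below would race. The host code below currently forces the atomic variant.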
template <typename T>
__global__ void adaptivemaxgradinput(T *gradInput, const T *gradOutput, const int64_t *indices,
                             int isizeH, int isizeW,
                             int osizeH, int osizeW)
{
  // iterators
  int oh, ow;

  // compute offsets based on thread/block ID
  int o_plane = blockIdx.x;
  int i_plane = o_plane;
  //int k = blockIdx.x % sizeD;

  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = blockDim.y*gridDim.y;

  // select input/output plane
  gradOutput = gradOutput + o_plane*osizeH*osizeW;
  gradInput = gradInput + i_plane*isizeH*isizeW;
  indices = indices + o_plane*osizeH*osizeW;

  // compute gradInput
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    for(ow = ostartW; ow < oendW; ow += ostepW) {

      const T *ptr_gradOutput = gradOutput + oh*osizeW + ow;
      const int64_t *ptr_ind = indices + oh*osizeW + ow;
      T z = *ptr_gradOutput;

      int argmax = (*ptr_ind);

      gradInput[argmax] += z;
    }
  }
}

/*
 * Description:
 *    this function computes the gradInput from the gradOutput and the argmax indices
 *    when the pooling windows can overlap (i.e. kH != dH or kW != dW), using atomic adds
 */
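// Note: gpuAtomicAddNoReturn makes the accumulation race-free, but floating-point
// atomics add in a non-deterministic order, which is why the host code calls
// alertNotDeterministic() for this backward.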
template <typename T>
__global__ void atomicadaptivemaxgradinput(
  T *gradInput, const T *gradOutput, const int64_t *indices,
  int isizeH, int isizeW, int osizeH, int osizeW
)
{
  // iterators
  int oh, ow;

  // compute offsets based on thread/block ID
  int o_plane = blockIdx.x;
  int i_plane = o_plane;

  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = blockDim.y*gridDim.y;

  // select input/output plane
  gradOutput = gradOutput + o_plane*osizeH*osizeW;
  gradInput = gradInput + i_plane*isizeH*isizeW;
  indices = indices + o_plane*osizeH*osizeW;

  // compute gradInput
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    for(ow = ostartW; ow < oendW; ow += ostepW) {

      const T *ptr_gradOutput = gradOutput + oh*osizeW + ow;
      const int64_t *ptr_ind = indices + oh*osizeW + ow;
      T z = *ptr_gradOutput;

      int argmax = (*ptr_ind);

      // atomic add since different threads could update same variable
      gpuAtomicAddNoReturn(&(gradInput[argmax]), z);
    }
  }
}
} // namespace

// 4d tensor B x D x H x W

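// For reference, a minimal ATen-level usage sketch (illustrative, not part of the
// original file): the structured kernel below backs calls such as
//
//   at::Tensor x = at::randn({1, 3, 8, 8}, at::kCUDA);
//   auto [out, idx] = at::adaptive_max_pool2d(x, {4, 4});
//
// where `out` holds the pooled maxima and `idx` the flattened per-plane argmax
// consumed by the backward implementation further down.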
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_cuda)
(const Tensor& input,
IntArrayRef output_size,
const Tensor& output,
const Tensor& indices) {
  TensorArg output_arg{output, "output", 1};
  TensorArg indices_arg{indices, "indices", 2};
  TensorArg input_arg{input, "input", 3};

  checkAllSameGPU(
      __func__, {output_arg, indices_arg, input_arg});
  if (input.numel() == 0) {
    return;
  }

  int64_t osizeH = output_size[0];
  int64_t osizeW = output_size[1];

  const at::Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
  const at::Tensor indices_c = indices.is_contiguous() ? indices : at::empty(indices.sizes(), indices.options());

  if (input.ndimension() == 3) {
    int64_t sizeD = input.size(0);
    int64_t isizeH = input.size(1);
    int64_t isizeW = input.size(2);

    int64_t istrideD = input.stride(0);
    int64_t istrideH = input.stride(1);
    int64_t istrideW = input.stride(2);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_cuda", [&] {
          const scalar_t* input_data = input.const_data_ptr<scalar_t>();
          scalar_t* output_data = output_c.mutable_data_ptr<scalar_t>();
          int64_t* indices_data = indices_c.mutable_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeD, blocksH);
          dim3 threads(32, 8);
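          // 32x8 = 256 threads per block stride over one output plane; grid.x holds
          // one block per plane and grid.y adds up to 16/sizeD row-blocks when the
          // number of planes is small.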

          // run maxpool kernel
          adaptivemaxpool<<<
              blocks,
              threads,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              input_data,
              output_data,
              indices_data,
              isizeH,
              isizeW,
              osizeH,
              osizeW,
              istrideD,
              istrideH,
              istrideW);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else {
    Tensor input_ = input.contiguous();
    int64_t sizeB = input_.size(0);
    int64_t sizeD = input_.size(1);
    int64_t isizeH = input_.size(2);
    int64_t isizeW = input_.size(3);

    // In the kernel, the batch and channel dimensions are treated as if they
    // are flattened and istrideD is used as the stride of this flattened dim
    // Handle the edge case where input_.size(1) == 1, where despite passing the
    // contiguity check the stride might not be H * W
    int64_t istrideD = isizeH * isizeW;
    int64_t istrideH = input_.stride(2);
    int64_t istrideW = input_.stride(3);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        input_.scalar_type(),
        "adaptive_max_pool2d_cuda",
        [&] {
          const scalar_t* input_data = input_.const_data_ptr<scalar_t>();
          scalar_t* output_data = output_c.mutable_data_ptr<scalar_t>();
          int64_t* indices_data = indices_c.mutable_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeB * sizeD, blocksH);
          dim3 threads(32, 8);

          // run maxpool kernel
          adaptivemaxpool<<<
              blocks,
              threads,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              input_data,
              output_data,
              indices_data,
              isizeH,
              isizeW,
              osizeH,
              osizeW,
              istrideD,
              istrideH,
              istrideW);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  }

  if (!output.is_contiguous()) {
    output.copy_(output_c);
  }
  if (!indices.is_contiguous()) {
    indices.copy_(indices_c);
  }
}

TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
(const Tensor& gradOutput,
 const Tensor& input,
 const Tensor& indices,
 const Tensor& gradInput) {
  globalContext().alertNotDeterministic(
      "adaptive_max_pool2d_backward_cuda");

  TensorArg grad_input_arg{gradInput, "gradInput", 1};
  TensorArg grad_output_arg{gradOutput, "gradOutput", 2};
  TensorArg input_arg{input, "input", 3};
  TensorArg indices_arg{indices, "indices", 4};

  checkAllSameGPU(
      __func__,
      {grad_input_arg, grad_output_arg, input_arg, indices_arg});

  if (gradOutput.numel() == 0) {
    return;
  }

  bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests
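  // Adaptive pooling windows can overlap (see start_index/end_index above), so
  // different output positions may share an argmax; forcing the atomic kernel keeps
  // the accumulation into gradInput race-free at some performance cost.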

  const at::Tensor gradOutput_ = gradOutput.contiguous();
  const at::Tensor indices_ = indices.contiguous();
  const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options());

  if (input.ndimension() == 3) {
    int64_t sizeD = input.size(0);
    int64_t isizeH = input.size(1);
    int64_t isizeW = input.size(2);

    int64_t osizeH = gradOutput_.size(1);
    int64_t osizeW = gradOutput_.size(2);

    // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);

    gradInput_c.zero_();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        input.scalar_type(),
        "adaptive_max_pool2d_backward_cuda",
        [&] {
          scalar_t* gradInput_data = gradInput_c.mutable_data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
          const int64_t* indices_data = indices_.const_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeD, blocksH);
          dim3 threads(32, 8);

          if (atomic) {
            // run updateGradInput kernel, accumulate gradients atomically
            atomicadaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          } else {
            // run updateGradInput kernel (note: this branch still launches the atomic variant)
            atomicadaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }
        });
  } else {
    int64_t sizeB = input.size(0);
    int64_t sizeD = input.size(1);
    int64_t isizeH = input.size(2);
    int64_t isizeW = input.size(3);

    int64_t osizeH = gradOutput_.size(2);
    int64_t osizeW = gradOutput_.size(3);

    gradInput_c.zero_();

    // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        input.scalar_type(),
        "adaptive_max_pool2d_backward_cuda",
        [&] {
          scalar_t* gradInput_data = gradInput_c.mutable_data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
          const int64_t* indices_data = indices_.const_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeB * sizeD, blocksH);
          dim3 threads(32, 8);

          if (atomic) {
            // run updateGradInput kernel, accumulate gradients atomically
            atomicadaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          } else {
            // run updateGradInput kernel (non-atomic variant)
            adaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }
        });
  }

  if (!gradInput.is_contiguous()) {
    gradInput.copy_(gradInput_c);
  }
}
} // namespace at::native