// aten/src/ATen/native/cuda/FractionalMaxPool3d.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/native/FractionalMaxPooling.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/fractional_max_pool3d_backward_native.h>
#include <ATen/ops/fractional_max_pool3d_native.h>
#endif

#include <algorithm>
#include <cfloat>
#include <cmath>

namespace at::native {

using namespace at::cuda::detail;

namespace {

template <typename scalar_t, typename accscalar_t>
__device__ inline int64_t get_intervals(
  accscalar_t sample,
  int64_t index,
  int64_t inputSize,
  int64_t outputSize,
  int64_t poolSize) {
    accscalar_t alpha = static_cast<accscalar_t>(inputSize - poolSize) /
      static_cast<accscalar_t>(outputSize - 1);
    if (index == outputSize - 1) {
      return inputSize - poolSize;
    } else {
      return static_cast<int64_t>((index + sample) * alpha) -
        static_cast<int64_t>(sample * alpha);
    }
  }

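// For intuition, a small worked example of get_intervals (values are
// illustrative): with inputSize = 9, outputSize = 4, poolSize = 2 and
// sample = 0.5, alpha = (9 - 2) / (4 - 1) = 7/3, and the returned window
// starts are
//   index 0: floor(0.5 * 7/3) - floor(0.5 * 7/3) = 0
//   index 1: floor(1.5 * 7/3) - floor(0.5 * 7/3) = 3 - 1 = 2
//   index 2: floor(2.5 * 7/3) - floor(0.5 * 7/3) = 5 - 1 = 4
//   index 3: inputSize - poolSize = 7   (the last window is pinned to the end)
// so each output index pools a width-poolSize window whose placement is driven
// by the per-plane random sample.
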
template <typename scalar_t>
__global__ void fractional_max_pool3d_out_frame(
  PackedTensorAccessor64<const scalar_t, 5> input,
  PackedTensorAccessor64<scalar_t, 5> output,
  PackedTensorAccessor64<int64_t, 5> indices,
  PackedTensorAccessor64<const scalar_t, 3> samples,
  int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
    using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
    // Output (t, h, w) point that this thread is responsible for
    int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
    int64_t plane = blockIdx.y;
    int64_t batch = blockIdx.z;
    // Each thread generates a specific output point
    if (ourOutputPoint < output.size(2) * output.size(3) *
      output.size(4)){
      int64_t outputT = ourOutputPoint / (output.size(3) *
                    output.size(4));
      int64_t outputH = (ourOutputPoint / output.size(4)) %
                    output.size(3);
      int64_t outputW = ourOutputPoint % output.size(4);

      int64_t poolT = get_intervals<scalar_t,accscalar_t>(
        static_cast<accscalar_t>(samples[batch][plane][0]),
        outputT, input.size(2), output.size(2), poolSizeT);
      int64_t poolH = get_intervals<scalar_t, accscalar_t>(
        static_cast<accscalar_t>(samples[batch][plane][1]),
        outputH, input.size(3), output.size(3), poolSizeH);
      int64_t poolW = get_intervals<scalar_t, accscalar_t>(
        static_cast<accscalar_t>(samples[batch][plane][2]),
        outputW, input.size(4), output.size(4), poolSizeW);

      scalar_t maxVal = at::numeric_limits<scalar_t>::lower_bound();
      int64_t maxIndex = poolT * input.size(3) * input.size(4) + poolH * input.size(4) + poolW;

      for(int64_t t = poolT; t < poolT + poolSizeT; ++t) {
        for (int64_t h = poolH; h < poolH + poolSizeH; ++h) {
          if(poolSizeW < 2 || poolSizeW > 7) {
            for (int64_t w = poolW; w < poolW + poolSizeW; ++w) {
              scalar_t val = input[batch][plane][t][h][w];
              // for consistency with THNN, favor the first max
              if (val > maxVal || at::_isnan(val)) {
                maxIndex = t * input.size(3) *
                  input.size(4) + h * input.size(4) + w;
                maxVal = val;
              }
            }
          } else {
            for (int64_t i = 0; i < poolSizeW; ++i) {
              int64_t w = i + poolW;
              scalar_t val = input[batch][plane][t][h][w];
              // for consistency with THNN, favor the first max
              if (val > maxVal || at::_isnan(val)) {
                maxIndex = t * input.size(3) * input.size(4) +
                  h * input.size(4) + w;
                maxVal = val;
              }
            }
          }
        }
      }

      indices[batch][plane][outputT][outputH][outputW] = maxIndex;
      output[batch][plane][outputT][outputH][outputW] = maxVal;
    }
  }

template <typename scalar_t>
__global__ void fractional_max_pool3d_backward_out_frame(
  PackedTensorAccessor64<scalar_t, 5> gradInput,
  PackedTensorAccessor64<const scalar_t, 5> gradOutput,
  PackedTensorAccessor64<const int64_t, 5> indices) {
  // Output (t, h, w) point that this thread is responsible for
  int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int64_t plane = blockIdx.y;
  int64_t batch = blockIdx.z;

  // Each thread generates a specific output point
  if (ourOutputPoint < gradOutput.size(2) *
    gradOutput.size(3) * gradOutput.size(4)) {
    int64_t outputW = ourOutputPoint % gradOutput.size(4);
    int64_t outputH = (ourOutputPoint / gradOutput.size(4)) %
                      gradOutput.size(3);
    int64_t outputT = ourOutputPoint / (gradOutput.size(3) *
                      gradOutput.size(4));

    int64_t index = indices[batch][plane][outputT][outputH][outputW];
    CUDA_KERNEL_ASSERT(index >= 0);
    int64_t inputW = index % gradInput.size(4);
    int64_t inputH = (index / gradInput.size(4)) %
      gradInput.size(3);
    int64_t inputT = index / (gradInput.size(3) *
      gradInput.size(4));
    CUDA_KERNEL_ASSERT(inputT < gradInput.size(2));

    gpuAtomicAddNoReturn(
      &gradInput[batch][plane][inputT][inputH][inputW],
      gradOutput[batch][plane][outputT][outputH][outputW]
      );
    }
  }

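// Note on the atomic add above: the pooling windows produced by get_intervals
// can overlap (adjacent window starts may be closer than poolSize apart), so
// several output points can pick the same input element as their max.
// Gradients are therefore accumulated with gpuAtomicAddNoReturn, and the
// unordered accumulation is why the backward entry points below call
// alertNotDeterministic.
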
void fractional_max_pool3d_backward_out_cuda_template(
  Tensor& gradInput,
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef output_size,
  const Tensor& indices) {
    int64_t dimt = 1;
    int64_t dimh = 2;
    int64_t dimw = 3;

    int64_t outputT = output_size[0];
    int64_t outputH = output_size[1];
    int64_t outputW = output_size[2];

    int64_t ndims = input.ndimension();
    if (ndims == 5) {
      dimt++;
      dimh++;
      dimw++;
    }

    /* sizes */
    int64_t inputT = input.size(dimt);
    int64_t inputH = input.size(dimh);
    int64_t inputW = input.size(dimw);

    TORCH_CHECK(
      outputT == gradOutput.size(dimt),
      "fractional_max_pool3d_backward_out_cuda_template(): ",
      "gradOutput time unexpected"
    );
    TORCH_CHECK(
      outputH == gradOutput.size(dimh),
      "fractional_max_pool3d_backward_out_cuda_template(): ",
      "gradOutput height unexpected"
    );
    TORCH_CHECK(
      outputW == gradOutput.size(dimw),
      "fractional_max_pool3d_backward_out_cuda_template(): ",
      "gradOutput width unexpected"
    );

    /* resize */
    gradInput.resize_as_(input);
    gradInput.zero_();

    auto gradInput_ = gradInput;
    auto gradOutput_ = gradOutput;
    auto indices_ = indices;

    if(ndims == 4) {
      gradInput_ = gradInput_.reshape({1, gradInput.size(0), inputT,
                                       inputH, inputW});
      gradOutput_ = gradOutput_.reshape({1, gradOutput.size(0), outputT,
                                         outputH, outputW});
      indices_ = indices_.reshape({1, indices.size(0), outputT, outputH,
                                   outputW});
    }

    if (gradInput.numel() == 0) {
      return;
    }

    /* backprop */
    // block is limited to 4 warps
    // grid.x handles the output points of each plane that do not fit in one block
    int64_t outputPlaneSize = gradOutput_.size(2) *
      gradOutput_.size(3) * gradOutput_.size(4);
    dim3 grid(
      (outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
      gradInput_.size(1),
      gradInput_.size(0));
    dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
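    // For intuition (illustrative shapes): with gradOutput of shape
    // (N=2, C=3, T=4, H=4, W=4), outputPlaneSize = 64, so the launch uses
    // block = 64 threads and grid = (1, 3, 2); each thread scatters the
    // gradient of one (t, h, w) output point of one (batch, plane) slice.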

    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      gradOutput.scalar_type(),
      "fractional_max_pool3d_backward_out_frame",
      [&] {
        fractional_max_pool3d_backward_out_frame<scalar_t>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
          gradInput_.packed_accessor64<scalar_t, 5>(),
          gradOutput_.packed_accessor64<const scalar_t, 5>(),
          indices_.packed_accessor64<const int64_t, 5>()
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    );
  }

}// namespace

TORCH_IMPL_FUNC(fractional_max_pool3d_out_cuda) (
  const Tensor& input,
  int64_t poolSizeT,
  int64_t poolSizeH,
  int64_t poolSizeW,
  int64_t outputT,
  int64_t outputH,
  int64_t outputW,
  const Tensor& randomSamples,
  int64_t numBatch,
  int64_t numPlanes,
  int64_t inputT,
  int64_t inputH,
  int64_t inputW,
  const Tensor& output,
  const Tensor& indices) {
  fractional_max_pool_check_shape</*ndim*/ 3>(input, randomSamples);

  auto output_ = output;
  auto indices_ = indices;
  auto input_ = input;

  int64_t ndims = input_.ndimension();
  if(ndims == 4) {
    output_ = output_.reshape({1, numPlanes, outputT, outputH, outputW});
    indices_ = indices_.reshape({1, numPlanes, outputT, outputH, outputW});
    input_ = input_.reshape({1, numPlanes, inputT, inputH, inputW});
  }
  if (output_.numel() == 0) {
    return;
  }

  // block is limited to 4 warps
  // grid.x handles the output points of each plane that do not fit in one block
  int64_t outputPlaneSize = output_.size(2) *
    output_.size(3) * output_.size(4);
  dim3 grid(
    (outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
    input_.size(1),
    input_.size(0));
  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    input.scalar_type(),
    "fractional_max_pool3d_out_frame",
    [&]{
      fractional_max_pool3d_out_frame<scalar_t>
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        input_.packed_accessor64<const scalar_t, 5>(),
        output_.packed_accessor64<scalar_t, 5>(),
        indices_.packed_accessor64<int64_t, 5>(),
        randomSamples.packed_accessor64<const scalar_t, 3>(),
        poolSizeT, poolSizeH, poolSizeW
      );
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  );
}

Tensor& fractional_max_pool3d_backward_out_cuda(const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef /*pool_size*/,
  IntArrayRef output_size,
  const at::Tensor& indices,
  at::Tensor& gradInput) {
    // See Note [Writing Nondeterministic Operations]
    // Nondeterministic because of atomicAdd usage
    globalContext().alertNotDeterministic("fractional_max_pool3d_backward_out_cuda");
    fractional_max_pool3d_backward_out_cuda_template(
      gradInput,
      gradOutput_,
      input,
      output_size,
      indices
    );
    return gradInput;
  }

Tensor fractional_max_pool3d_backward_cuda(
  const at::Tensor& gradOutput,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices) {
    // See Note [Writing Nondeterministic Operations]
    // Nondeterministic because of atomicAdd usage
    globalContext().alertNotDeterministic("fractional_max_pool3d_backward_cuda");
    Tensor gradInput = at::empty({0}, input.options());
    fractional_max_pool3d_backward_out_cuda_template(
      gradInput,
      gradOutput,
      input,
      output_size,
      indices
    );
    return gradInput;
 }

}// namespace at::native
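
// A minimal usage sketch (shapes and values are illustrative, assuming only
// the public ATen API): the kernels above back at::fractional_max_pool3d,
// which takes a 4-D (C, T, H, W) or 5-D (N, C, T, H, W) input, a kernel_size,
// an output_size, and per-plane random samples in [0, 1) (shape (N, C, 3) for
// a batched input), and returns the pooled output together with the flattened
// t * inputH * inputW + h * inputW + w indices consumed by the backward ops:
//
//   at::Tensor input = at::randn({2, 3, 8, 8, 8}, at::kCUDA);
//   at::Tensor samples = at::rand({2, 3, 3}, input.options());
//   auto [output, indices] = at::fractional_max_pool3d(
//       input, /*kernel_size=*/{2, 2, 2}, /*output_size=*/{4, 4, 4}, samples);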