#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/native/FractionalMaxPooling.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/fractional_max_pool2d_backward_native.h>
#include <ATen/ops/fractional_max_pool2d_native.h>
#endif

#include <algorithm>
#include <cfloat>
#include <cmath>

namespace at::native {

using namespace at::cuda::detail;

namespace {

template <typename scalar_t, typename accscalar_t>
__device__ inline int get_interval(accscalar_t sample,
  int index, int inputSize, int outputSize, int poolSize) {
  accscalar_t alpha = static_cast<accscalar_t>(inputSize - poolSize) /
    static_cast<accscalar_t>(outputSize - 1);
  if (index == outputSize - 1) {
    return inputSize - poolSize;
  } else {
    return static_cast<int>((index + sample) * alpha) -
      static_cast<int>(sample * alpha);
  }
}
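
// Example of the interval mapping computed by get_interval (illustrative
// values, not taken from any call site): with inputSize = 10, poolSize = 2 and
// outputSize = 5, alpha = (10 - 2) / (5 - 1) = 2.0. For sample = 0.3 the
// window starts are
//   index 0..3: floor((i + 0.3) * 2.0) - floor(0.3 * 2.0) = 0, 2, 4, 6
//   index 4 (last): inputSize - poolSize = 8
// so every poolSize-wide window stays inside [0, inputSize).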

template <typename scalar_t>
__global__ void fractional_max_pool2d_out_cuda_frame(
  PackedTensorAccessor<scalar_t, 4> output,
  PackedTensorAccessor<int64_t, 4> indices,
  PackedTensorAccessor<const scalar_t, 4> input,
  PackedTensorAccessor<const scalar_t, 3> samples,
  int poolSizeH, int poolSizeW) {

  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;

  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int plane = blockIdx.y;
  int batch = blockIdx.z;

  // Each thread generates a specific output point
  if (ourOutputPoint < output.size(2) * output.size(3)) {
    int outputW = ourOutputPoint % output.size(3);
    int outputH = ourOutputPoint / output.size(3);

    int poolW = get_interval<scalar_t, accscalar_t>(
      static_cast<accscalar_t>(samples[batch][plane][0]),
        outputW, input.size(3), output.size(3), poolSizeW);
    int poolH = get_interval<scalar_t, accscalar_t>(
      static_cast<accscalar_t>(samples[batch][plane][1]),
        outputH, input.size(2), output.size(2), poolSizeH);

    scalar_t maxVal = at::numeric_limits<scalar_t>::lower_bound();
    int maxIndex = poolH * input.size(3) + poolW;

    for (int h = poolH; h < poolH + poolSizeH; ++h) {
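      // Two equivalent inner loops over the window width: widths 2..7 use a
      // zero-based counter, which presumably makes it easier for the compiler
      // to unroll common small windows; other widths take the generic loop
      // over absolute column indices.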
      if (poolSizeW < 2 || poolSizeW > 7) {
        for (int w = poolW; w < poolW + poolSizeW; ++w) {
          scalar_t val = input[batch][plane][h][w];
          // for consistency with THNN, favor the first max
          if (val > maxVal || at::_isnan(val)) {
            maxIndex = h * input.size(3) + w;
            maxVal = val;
          }
        }
      } else {
        for (int i = 0; i < poolSizeW; ++i) {
          int w = i + poolW;
          scalar_t val = input[batch][plane][h][w];
          // for consistency with THNN, favor the first max
          if (val > maxVal || at::_isnan(val)) {
            maxIndex = h * input.size(3) + w;
            maxVal = val;
          }
        }
      }
    }

    indices[batch][plane][outputH][outputW] = maxIndex;
    output[batch][plane][outputH][outputW] = maxVal;
  }
}

template <typename scalar_t>
__global__ void fractional_max_pool2d_backward_out_cuda_frame(
  PackedTensorAccessor<scalar_t, 4> gradInput,
  PackedTensorAccessor<const scalar_t, 4> gradOutput,
  PackedTensorAccessor<const int64_t, 4> indices) {
  // Output (h, w) point that this thread is responsible for
  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int plane = blockIdx.y;
  int batch = blockIdx.z;

  // Each thread generates a specific output point
  if (ourOutputPoint < gradOutput.size(2) *
    gradOutput.size(3)) {
    int outputW = ourOutputPoint % gradOutput.size(3);
    int outputH = ourOutputPoint / gradOutput.size(3);

    int index = indices[batch][plane][outputH][outputW];
    CUDA_KERNEL_ASSERT(index >= 0);
    int inputW = index % gradInput.size(3);
    int inputH = index / gradInput.size(3);
    CUDA_KERNEL_ASSERT(inputH < gradInput.size(2));

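    // Fractional pooling windows can overlap, so several output positions may
    // point at the same input element; accumulate the gradient with an atomic
    // add to avoid lost updates.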
    gpuAtomicAddNoReturn(
      &gradInput[batch][plane][inputH][inputW],
      gradOutput[batch][plane][outputH][outputW]
    );
  }
}

} // anonymous namespace

TORCH_IMPL_FUNC(fractional_max_pool2d_out_cuda) (
  const Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const Tensor& randomSamples,
  const Tensor& output,
  const Tensor& indices
) {
  fractional_max_pool_check_shape</*ndim*/ 2>(input, randomSamples);

  int planeDim = 0;

  int ndims = input.ndimension();

  if (ndims == 4) {
    planeDim++;
  }

  /* sizes */
  int numPlanes = input.size(planeDim);

  int outputH = output_size[0];
  int outputW = output_size[1];
  int poolSizeH = pool_size[0];
  int poolSizeW = pool_size[1];

  auto output_ = output;
  auto input_ = input;
  auto indices_ = indices;

  if (ndims == 3) {
    output_ = output_.reshape({1, numPlanes, outputH, outputW});
    indices_ = indices_.reshape({1, numPlanes, outputH, outputW});
    input_ = input_.reshape({1, input.size(0), input.size(1), input.size(2)});
  }

  if (output_.numel() == 0) {
    return;
  }

  // block is limited to 4 warps
  // grid handles overflow per each plane
  int outputPlaneSize = output_.size(2) *
    output_.size(3);
  dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
            input_.size(1),
            input_.size(0));
  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
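  // Launch-shape example with hypothetical sizes: a 64x64 output plane gives
  // outputPlaneSize = 4096, so the grid has (4096 + 127) / 128 = 32 blocks in
  // x, one y slot per plane and one z slot per batch element, each block
  // running 128 threads (4 warps). Planes smaller than 128 points launch a
  // single block of outputPlaneSize threads.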

  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    input.scalar_type(),
    "fractional_max_pool2d_out_cuda_frame",
    [&] {
      auto devInput = input_.packed_accessor64<const scalar_t, 4>();
      auto devOutput = output_.packed_accessor64<scalar_t, 4>();
      auto devIndices = indices_.packed_accessor64<int64_t, 4>();
      auto devSamples = randomSamples.packed_accessor64<const scalar_t, 3>();
      fractional_max_pool2d_out_cuda_frame<scalar_t>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
          devOutput, devIndices, devInput, devSamples,
          poolSizeH, poolSizeW);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  );
}

TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cuda)(
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef pool_size /* unused */,
  IntArrayRef output_size,
  const Tensor& indices,
  const Tensor& gradInput)
{

  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("fractional_max_pool2d_backward_cuda");
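  // When deterministic algorithms are requested (e.g. via
  // torch.use_deterministic_algorithms(True)), this alert raises an error, or
  // only warns if warn-only mode was chosen.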

  int dimh = 1;
  int dimw = 2;

  int ndims = input.ndimension();
  if (ndims == 4) {
    dimh++;
    dimw++;
  }

  /* sizes */
  int inputH = input.size(dimh);
  int inputW = input.size(dimw);

  int outputH = output_size[0];
  int outputW = output_size[1];

  if (gradInput.numel() == 0) {
    return;
  }

  gradInput.zero_();

  auto gradInput_ = gradInput;
  auto gradOutput_ = gradOutput;
  auto indices_ = indices;

  if (ndims == 3) {
    gradInput_ = gradInput_.reshape({1, input.size(0), inputH, inputW});
    gradOutput_ = gradOutput_.reshape({1, gradOutput.size(0), outputH, outputW});
    indices_ = indices_.reshape({1, indices_.size(0), outputH, outputW});
  }

  /* backprop */
  // block is limited to 4 warps
  // grid handles overflow per each plane
  int outputPlaneSize = gradOutput_.size(2) *
    gradOutput_.size(3);
  dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
            gradInput_.size(1),
            gradInput_.size(0));
  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
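  // Same launch shape as the forward pass: one thread per gradOutput point in
  // 128-thread (4-warp) blocks along x, with planes on grid.y and batch on
  // grid.z taken from gradInput_, whose leading dims match gradOutput_.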

  auto devIndices = indices_.packed_accessor64<const int64_t, 4>();
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    gradOutput.scalar_type(),
    "fractional_max_pool2d_backward_out_cuda_frame",
    [&] {
      auto devGradInput = gradInput_.packed_accessor64<scalar_t, 4>();
      auto devGradOutput = gradOutput_.packed_accessor64<const scalar_t, 4>();
      fractional_max_pool2d_backward_out_cuda_frame<scalar_t>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        devGradInput, devGradOutput, devIndices);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  );
}

} // namespace at::native