#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/native/FractionalMaxPooling.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/fractional_max_pool2d_backward_native.h>
#include <ATen/ops/fractional_max_pool2d_native.h>
#endif

#include <algorithm>
#include <cfloat>
#include <cmath>

namespace at::native {

using namespace at::cuda::detail;

namespace {

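// Computes the start of the pooling window for a given output index.
// With alpha = (inputSize - poolSize) / (outputSize - 1) and a random
// sample u in [0, 1), the window for output index i starts at
//   floor((i + u) * alpha) - floor(u * alpha),
// except for the last output index, which is pinned to inputSize - poolSize
// so the final window ends exactly at the input edge.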
template <typename scalar_t, typename accscalar_t>
__device__ inline int get_interval(accscalar_t sample,
    int index, int inputSize, int outputSize, int poolSize) {
  accscalar_t alpha = static_cast<accscalar_t>(inputSize - poolSize) /
    static_cast<accscalar_t>(outputSize - 1);
  if (index == outputSize - 1) {
    return inputSize - poolSize;
  } else {
    return static_cast<int>((index + sample) * alpha) -
      static_cast<int>(sample * alpha);
  }
}

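// Forward kernel: each thread computes one output element (outputH, outputW)
// for the plane given by blockIdx.y and the batch element given by blockIdx.z.
// It scans its pooling window and records the max value together with the
// flattened (h * W + w) input index of the first maximum; NaNs win the
// comparison so they propagate to the output.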
template <typename scalar_t>
__global__ void fractional_max_pool2d_out_cuda_frame(
  PackedTensorAccessor<scalar_t, 4> output,
  PackedTensorAccessor<int64_t, 4> indices,
  PackedTensorAccessor<const scalar_t, 4> input,
  PackedTensorAccessor<const scalar_t, 3> samples,
  int poolSizeH, int poolSizeW) {

  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;

  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int plane = blockIdx.y;
  int batch = blockIdx.z;

  // Each thread generates a specific output point
  if (ourOutputPoint < output.size(2) * output.size(3)) {
    int outputW = ourOutputPoint % output.size(3);
    int outputH = ourOutputPoint / output.size(3);

    int poolW = get_interval<scalar_t, accscalar_t>(
      static_cast<accscalar_t>(samples[batch][plane][0]),
      outputW, input.size(3), output.size(3), poolSizeW);
    int poolH = get_interval<scalar_t, accscalar_t>(
      static_cast<accscalar_t>(samples[batch][plane][1]),
      outputH, input.size(2), output.size(2), poolSizeH);

    scalar_t maxVal = at::numeric_limits<scalar_t>::lower_bound();
    int maxIndex = poolH * input.size(3) + poolW;

    for (int h = poolH; h < poolH + poolSizeH; ++h) {
      if (poolSizeW < 2 || poolSizeW > 7) {
        for (int w = poolW; w < poolW + poolSizeW; ++w) {
          scalar_t val = input[batch][plane][h][w];
          // for consistency with THNN, favor the first max
          if (val > maxVal || at::_isnan(val)) {
            maxIndex = h * input.size(3) + w;
            maxVal = val;
          }
        }
      } else {
        for (int i = 0; i < poolSizeW; ++i) {
          int w = i + poolW;
          scalar_t val = input[batch][plane][h][w];
          // for consistency with THNN, favor the first max
          if (val > maxVal || at::_isnan(val)) {
            maxIndex = h * input.size(3) + w;
            maxVal = val;
          }
        }
      }
    }

    indices[batch][plane][outputH][outputW] = maxIndex;
    output[batch][plane][outputH][outputW] = maxVal;
  }
}

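// Backward kernel: each thread handles one gradOutput element and scatters it
// into gradInput at the flattened (h, w) index saved by the forward pass.
// Distinct output windows may map to the same input location, so the add is
// done atomically, which is also why this backward pass is nondeterministic.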
template <typename scalar_t>
__global__ void fractional_max_pool2d_backward_out_cuda_frame(
  PackedTensorAccessor<scalar_t, 4> gradInput,
  PackedTensorAccessor<const scalar_t, 4> gradOutput,
  PackedTensorAccessor<const int64_t, 4> indices) {
  // Output (h, w) point that this thread is responsible for
  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int plane = blockIdx.y;
  int batch = blockIdx.z;

  // Each thread generates a specific output point
  if (ourOutputPoint < gradOutput.size(2) *
    gradOutput.size(3)) {
    int outputW = ourOutputPoint % gradOutput.size(3);
    int outputH = ourOutputPoint / gradOutput.size(3);

    int index = indices[batch][plane][outputH][outputW];
    CUDA_KERNEL_ASSERT(index >= 0);
    int inputW = index % gradInput.size(3);
    int inputH = index / gradInput.size(3);
    CUDA_KERNEL_ASSERT(inputH < gradInput.size(2));

    gpuAtomicAddNoReturn(
      &gradInput[batch][plane][inputH][inputW],
      gradOutput[batch][plane][outputH][outputW]
    );
  }
}

} // anonymous namespace

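// Host-side forward: checks the random-sample shape, views a 3-D (C, H, W)
// input as (1, C, H, W), and launches one thread per output element in blocks
// of at most 128 threads (4 warps), with blockIdx.y indexing the plane and
// blockIdx.z the batch element.
//
// Rough usage sketch of the op this implementation backs (example sizes only):
//   auto [out, idx] = at::fractional_max_pool2d(
//       input, /*kernel_size=*/{2, 2}, /*output_size=*/{8, 8}, random_samples);
// where random_samples holds values in [0, 1) with shape (N, C, 2).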
TORCH_IMPL_FUNC(fractional_max_pool2d_out_cuda) (
  const Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const Tensor& randomSamples,
  const Tensor& output,
  const Tensor& indices
) {
  fractional_max_pool_check_shape</*ndim*/ 2>(input, randomSamples);

  int planeDim = 0;

  int ndims = input.ndimension();

  if (ndims == 4) {
    planeDim++;
  }

  /* sizes */
  int numPlanes = input.size(planeDim);

  int outputH = output_size[0];
  int outputW = output_size[1];
  int poolSizeH = pool_size[0];
  int poolSizeW = pool_size[1];

  auto output_ = output;
  auto input_ = input;
  auto indices_ = indices;

  if (ndims == 3) {
    output_ = output_.reshape({1, numPlanes, outputH, outputW});
    indices_ = indices_.reshape({1, numPlanes, outputH, outputW});
    input_ = input_.reshape({1, input.size(0), input.size(1), input.size(2)});
  }

  if (output_.numel() == 0) {
    return;
  }

  // block is limited to 4 warps
  // grid handles overflow per each plane
  int outputPlaneSize = output_.size(2) *
    output_.size(3);
  dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
    input_.size(1),
    input_.size(0));
  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    input.scalar_type(),
    "fractional_max_pool2d_out_cuda_frame",
    [&] {
      auto devInput = input_.packed_accessor64<const scalar_t, 4>();
      auto devOutput = output_.packed_accessor64<scalar_t, 4>();
      auto devIndices = indices_.packed_accessor64<int64_t, 4>();
      auto devSamples = randomSamples.packed_accessor64<const scalar_t, 3>();
      fractional_max_pool2d_out_cuda_frame<scalar_t>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        devOutput, devIndices, devInput, devSamples,
        poolSizeH, poolSizeW);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  );
}

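// Host-side backward: zeroes gradInput and scatters gradOutput through the
// saved indices, using the same launch geometry as the forward pass.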
TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cuda)(
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef pool_size /* unused */,
  IntArrayRef output_size,
  const Tensor& indices,
  const Tensor& gradInput)
{

  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("fractional_max_pool2d_backward_cuda");

  int dimh = 1;
  int dimw = 2;

  int ndims = input.ndimension();
  if (ndims == 4) {
    dimh++;
    dimw++;
  }

  /* sizes */
  int inputH = input.size(dimh);
  int inputW = input.size(dimw);

  int outputH = output_size[0];
  int outputW = output_size[1];

  if (gradInput.numel() == 0) {
    return;
  }

  gradInput.zero_();

  auto gradInput_ = gradInput;
  auto gradOutput_ = gradOutput;
  auto indices_ = indices;

  if (ndims == 3) {
    gradInput_ = gradInput_.reshape({1, input.size(0), inputH, inputW});
    gradOutput_ = gradOutput_.reshape({1, gradOutput.size(0), outputH, outputW});
    indices_ = indices_.reshape({1, indices_.size(0), outputH, outputW});
  }

  /* backprop */
  // block is limited to 4 warps
  // grid handles overflow per each plane
  int outputPlaneSize = gradOutput_.size(2) *
    gradOutput_.size(3);
  dim3 grid((outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
    gradInput_.size(1),
    gradInput_.size(0));
  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

  auto devIndices = indices_.packed_accessor64<const int64_t, 4>();
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    gradOutput.scalar_type(),
    "fractional_max_pool2d_backward_out_cuda_frame",
    [&] {
      auto devGradInput = gradInput_.packed_accessor64<scalar_t, 4>();
      auto devGradOutput = gradOutput_.packed_accessor64<const scalar_t, 4>();
      fractional_max_pool2d_backward_out_cuda_frame<scalar_t>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        devGradInput, devGradOutput, devIndices);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  );
}

} // namespace at::native