#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/adaptive_max_pool2d_backward_native.h>
#include <ATen/ops/adaptive_max_pool2d_native.h>
#include <ATen/ops/empty.h>
#endif

#include <algorithm>
#include <cfloat>
#include <cmath>


namespace at::native {

namespace {

__device__ inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
  return (a / b) * c + ((a % b) * c) / b;
}

__device__ inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
  return 1 + ((a + 1) * c - 1) / b;
}
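
// Worked example (illustrative only): pooling an input extent of 5 down to an
// output extent of 3 gives the windows
//   output 0 -> [start_index(0, 3, 5), end_index(0, 3, 5)) = [0, 2)
//   output 1 -> [start_index(1, 3, 5), end_index(1, 3, 5)) = [1, 4)
//   output 2 -> [start_index(2, 3, 5), end_index(2, 3, 5)) = [3, 5)
// Adjacent windows can overlap whenever the input extent is not an integer
// multiple of the output extent.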

// 4d tensor B x D x H x W

/*
 * Description:
 *    this function adaptively maxpools an input 4D tensor along dimensions 2 and 3
 *    4D input, 4D output, 4D argmax x and y
 */
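// Parallelization (as configured by the launches below): blockIdx.x selects
// one (batch * channel) plane; within that plane, threads stride over the
// output width via threadIdx.x with step blockDim.x, and over the output
// height via blockDim.y * blockIdx.y + threadIdx.y with step
// blockDim.y * gridDim.y. Each thread scans its pooling window serially.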
template <typename T>
__global__ void adaptivemaxpool(const T *input, T *output, int64_t *indices,
                        int isizeH, int isizeW,
                        int osizeH, int osizeW,
                        int64_t istrideD, int64_t istrideH, int64_t istrideW)
{
  // iterators
  int oh, ow;

  // compute offsets based on thread/block ID
  int o_plane = blockIdx.x;
  int i_plane = o_plane;

  int ostartW = threadIdx.x;
  int oendW = osizeW;
  const int ostepW = blockDim.x;

  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  int oendH = osizeH;
  const int ostepH = blockDim.y*gridDim.y;
  // select input/output plane
  output = output + o_plane*osizeH*osizeW;
  input = input + i_plane*istrideD;
  indices = indices + o_plane*osizeH*osizeW;

  // For all output pixels...
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    int istartH = start_index(oh, osizeH, isizeH);
    int iendH = end_index(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for(ow = ostartW; ow < oendW; ow += ostepW) {
      int istartW = start_index(ow, osizeW, isizeW);
      int iendW = end_index(ow, osizeW, isizeW);

      int kW = iendW - istartW;

      // Find the max of the pooling window in the input plane
      const T *ptr_input = input + istartH*istrideH + istartW*istrideW;
      T *ptr_output = output + oh*osizeW + ow;
      int64_t *ptr_ind = indices + oh*osizeW + ow;
      int argmax = istartH * isizeW + istartW;
      T max = at::numeric_limits<T>::lower_bound(); // -Infinity
      int ih, iw;
      for(ih = 0; ih < kH; ih++) {
        for(iw = 0; iw < kW; iw++) {
          T val = ptr_input[iw*istrideW];
          // treat NaN as the max so it propagates to the output
          // (val > max is false when val is NaN)
          if ((val > max) || at::_isnan(val)) {
            max = val;
            argmax = (ih+istartH)*isizeW + iw+istartW;
          }
        }
        ptr_input += istrideH; // next input line
      }
      // Update output and argmax
      *ptr_output = max;
      *ptr_ind = argmax;
    }
  }
}

/*
 * Description:
 *    this function computes the gradInput from the max indices and gradOutput
 */
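// Note on the kernel below: each thread performs `gradInput[argmax] += z` for
// its own output pixels, so the plain accumulation is only race-free when no
// two output windows record the same argmax. Overlapping adaptive windows can
// violate that assumption, which is why the atomic variant further down is
// the one used in practice.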
template <typename T>
__global__ void adaptivemaxgradinput(T *gradInput, const T *gradOutput, const int64_t *indices,
                             int isizeH, int isizeW,
                             int osizeH, int osizeW)
{
  // iterators
  int oh, ow;

  // compute offsets based on thread/block ID
  int o_plane = blockIdx.x;
  int i_plane = o_plane;
  //int k = blockIdx.x % sizeD;

  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = blockDim.y*gridDim.y;

  // select input/output plane
  gradOutput = gradOutput + o_plane*osizeH*osizeW;
  gradInput = gradInput + i_plane*isizeH*isizeW;
  indices = indices + o_plane*osizeH*osizeW;

  // compute gradInput
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    for(ow = ostartW; ow < oendW; ow += ostepW) {

      const T *ptr_gradOutput = gradOutput + oh*osizeW + ow;
      const int64_t *ptr_ind = indices + oh*osizeW + ow;
      T z = *ptr_gradOutput;

      int argmax = (*ptr_ind);

      gradInput[argmax] += z;
    }
  }
}

/*
 * Description:
 *    this function computes the gradInput from the max indices and gradOutput
 *    when pooling windows can overlap (uses atomic add)
 */
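// In the 5 -> 3 example near the top of this file, the windows [0, 2),
// [1, 4) and [3, 5) overlap, so two different output pixels may record the
// same argmax; the atomic add below makes the concurrent accumulation into
// gradInput safe.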
template <typename T>
__global__ void atomicadaptivemaxgradinput(
  T *gradInput, const T *gradOutput, const int64_t *indices,
  int isizeH, int isizeW, int osizeH, int osizeW
)
{
  // iterators
  int oh, ow;

  // compute offsets based on thread/block ID
  int o_plane = blockIdx.x;
  int i_plane = o_plane;

  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = blockDim.y*gridDim.y;

  // select input/output plane
  gradOutput = gradOutput + o_plane*osizeH*osizeW;
  gradInput = gradInput + i_plane*isizeH*isizeW;
  indices = indices + o_plane*osizeH*osizeW;

  // compute gradInput
  for(oh = ostartH; oh < oendH; oh += ostepH) {

    for(ow = ostartW; ow < oendW; ow += ostepW) {

      const T *ptr_gradOutput = gradOutput + oh*osizeW + ow;
      const int64_t *ptr_ind = indices + oh*osizeW + ow;
      T z = *ptr_gradOutput;

      int argmax = (*ptr_ind);

      // atomic add since different threads could update same variable
      gpuAtomicAddNoReturn(&(gradInput[argmax]), z);
    }
  }
}
} // namespace

// 4d tensor B x D x H x W
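// Illustrative call path (a sketch, not part of this file's API surface): a
// dispatcher call such as
//   auto [out, idx] = at::adaptive_max_pool2d(input, {5, 7});
// lands in the structured implementations below for CUDA tensors, with `idx`
// holding flat h * W + w offsets into each input plane, as produced by the
// kernels above.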

TORCH_IMPL_FUNC(adaptive_max_pool2d_out_cuda)
(const Tensor& input,
 IntArrayRef output_size,
 const Tensor& output,
 const Tensor& indices) {
  TensorArg output_arg{output, "output", 1};
  TensorArg indices_arg{indices, "indices", 2};
  TensorArg input_arg{input, "input", 3};

  checkAllSameGPU(
      __func__, {output_arg, indices_arg, input_arg});
  if (input.numel() == 0) {
    return;
  }

  int64_t osizeH = output_size[0];
  int64_t osizeW = output_size[1];

  const at::Tensor output_c = output.is_contiguous()
      ? output
      : at::empty(output.sizes(), output.options());
  const at::Tensor indices_c = indices.is_contiguous()
      ? indices
      : at::empty(indices.sizes(), indices.options());

  if (input.ndimension() == 3) {
    int64_t sizeD = input.size(0);
    int64_t isizeH = input.size(1);
    int64_t isizeW = input.size(2);

    int64_t istrideD = input.stride(0);
    int64_t istrideH = input.stride(1);
    int64_t istrideW = input.stride(2);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_cuda", [&] {
          const scalar_t* input_data = input.const_data_ptr<scalar_t>();
          scalar_t* output_data = output_c.mutable_data_ptr<scalar_t>();
          int64_t* indices_data = indices_c.mutable_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeD, blocksH);
          dim3 threads(32, 8);

          // run maxpool kernel
          adaptivemaxpool<<<
              blocks,
              threads,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              input_data,
              output_data,
              indices_data,
              isizeH,
              isizeW,
              osizeH,
              osizeW,
              istrideD,
              istrideH,
              istrideW);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  } else {
    Tensor input_ = input.contiguous();
    int64_t sizeB = input_.size(0);
    int64_t sizeD = input_.size(1);
    int64_t isizeH = input_.size(2);
    int64_t isizeW = input_.size(3);

    // In the kernel, the batch and channel dimensions are treated as if they
    // were flattened, and istrideD is used as the stride of that flattened
    // dimension. Computing it as H * W (rather than reading input_.stride(1))
    // also handles the edge case input_.size(1) == 1, where the tensor passes
    // the contiguity check even though stride(1) need not equal H * W.
    int64_t istrideD = isizeH * isizeW;
    int64_t istrideH = input_.stride(2);
    int64_t istrideW = input_.stride(3);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        input_.scalar_type(),
        "adaptive_max_pool2d_cuda",
        [&] {
          const scalar_t* input_data = input_.const_data_ptr<scalar_t>();
          scalar_t* output_data = output_c.mutable_data_ptr<scalar_t>();
          int64_t* indices_data = indices_c.mutable_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeB * sizeD, blocksH);
          dim3 threads(32, 8);

          // run maxpool kernel
          adaptivemaxpool<<<
              blocks,
              threads,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              input_data,
              output_data,
              indices_data,
              isizeH,
              isizeW,
              osizeH,
              osizeW,
              istrideD,
              istrideH,
              istrideW);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  }

  if (!output.is_contiguous()) {
    output.copy_(output_c);
  }
  if (!indices.is_contiguous()) {
    indices.copy_(indices_c);
  }
}

TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
(const Tensor& gradOutput,
 const Tensor& input,
 const Tensor& indices,
 const Tensor& gradInput) {
  globalContext().alertNotDeterministic(
      "adaptive_max_pool2d_backward_cuda");

  TensorArg grad_input_arg{gradInput, "gradInput", 1};
  TensorArg grad_output_arg{gradOutput, "gradOutput", 2};
  TensorArg input_arg{input, "input", 3};
  TensorArg indices_arg{indices, "indices", 4};

  checkAllSameGPU(
      __func__,
      {grad_input_arg, grad_output_arg, input_arg, indices_arg});

  if (gradOutput.numel() == 0) {
    return;
  }

  // Always take the atomic path: suboptimal, but the non-atomic kernel can
  // race when overlapping pooling windows record the same argmax, and without
  // atomics the op does not pass the tests.
  bool atomic = true;

  const at::Tensor gradOutput_ = gradOutput.contiguous();
  const at::Tensor indices_ = indices.contiguous();
  const at::Tensor gradInput_c = gradInput.is_contiguous()
      ? gradInput
      : at::empty(gradInput.sizes(), gradInput.options());

  if (input.ndimension() == 3) {
    int64_t sizeD = input.size(0);
    int64_t isizeH = input.size(1);
    int64_t isizeW = input.size(2);

    int64_t osizeH = gradOutput_.size(1);
    int64_t osizeW = gradOutput_.size(2);

    // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);

    gradInput_c.zero_();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        input.scalar_type(),
        "adaptive_max_pool2d_backward_cuda",
        [&] {
          scalar_t* gradInput_data = gradInput_c.mutable_data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
          const int64_t* indices_data = indices_.const_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeD, blocksH);
          dim3 threads(32, 8);

          if (atomic) {
            // run updateGradInput kernel, accumulate gradients atomically
            atomicadaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          } else {
            // run updateGradInput kernel (non-atomic path; unused while
            // `atomic` is hard-coded to true above)
            adaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }
        });
  } else {
    int64_t sizeB = input.size(0);
    int64_t sizeD = input.size(1);
    int64_t isizeH = input.size(2);
    int64_t isizeW = input.size(3);

    int64_t osizeH = gradOutput_.size(2);
    int64_t osizeW = gradOutput_.size(3);

    gradInput_c.zero_();

    // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);

    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        input.scalar_type(),
        "adaptive_max_pool2d_backward_cuda",
        [&] {
          scalar_t* gradInput_data = gradInput_c.mutable_data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
          const int64_t* indices_data = indices_.const_data_ptr<int64_t>();

          // cuda blocks & threads:
          int blocksH = (int)(16L / sizeD);
          blocksH = blocksH < 1 ? 1 : blocksH;
          dim3 blocks(sizeB * sizeD, blocksH);
          dim3 threads(32, 8);

          if (atomic) {
            // run updateGradInput kernel, accumulate gradients atomically
            atomicadaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          } else {
            // run updateGradInput kernel (non-atomic path; unused while
            // `atomic` is hard-coded to true above)
            adaptivemaxgradinput<<<
                blocks,
                threads,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                gradInput_data,
                gradOutput_data,
                indices_data,
                isizeH,
                isizeW,
                osizeH,
                osizeW);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }
        });
  }

  if (!gradInput.is_contiguous()) {
    gradInput.copy_(gradInput_c);
  }
}
} // namespace at::native