#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
#endif

namespace at::native {
namespace {

__device__ inline int min(int a, int b) {
  return a <= b ? a : b;
}

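// Each thread computes one output element: threadIdx/blockIdx x and y map to
// the output column and row, and the z dimension (shifted by offsetZ across
// launches) enumerates (batch, channel, time) output planes. How the z index
// is decomposed depends on the memory format, as the branches below show.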
template <typename scalar_t>
__global__ static void max_pool3d_with_indices_single_out_frame(
    const scalar_t* inputData,
    scalar_t* outputData,
    int64_t* indicesData,
    int features,
    int itime, int iheight, int iwidth,
    int obatch, int otime, int oheight, int owidth,
    int kT, int kH, int kW,
    int dT, int dH, int dW,
    int pT, int pH, int pW,
    int dilationT, int dilationH, int dilationW,
    int offsetZ,
    bool channels_last)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow = blockIdx.y * blockDim.y + threadIdx.y;
  int oFrame = 0;
  // used only for channels-first indexing
  int64_t slice = 0;
  // used only for channels-last indexing
  int batch = 0;
  int channel = 0;
  if (!channels_last) {
    // indexing order: batch, channel, time
    oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature
  } else {
    // indexing order: batch, time, channel
    channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel)
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch * otime + time)
    batch = slice / otime;
    oFrame = slice % otime;
  }

  // For int64_t data type, see https://github.com/pytorch/pytorch/issues/52822
  if (oRow < oheight && oColumn < owidth && oFrame < otime && channel < features && batch < obatch)
  {
    int tStart = oFrame * dT - pT;
    int hStart = oRow * dH - pH;
    int wStart = oColumn * dW - pW;
    int tEnd = min(tStart + (kT - 1) * dilationT + 1, itime);
    int hEnd = min(hStart + (kH - 1) * dilationH + 1, iheight);
    int wEnd = min(wStart + (kW - 1) * dilationW + 1, iwidth);

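    // Padding can place the window start before the input; advance each start
    // coordinate in dilation-sized steps until it is in-bounds. The *End
    // values above are already clamped to the input extent.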
    while (tStart < 0)
      tStart += dilationT;
    while (hStart < 0)
      hStart += dilationH;
    while (wStart < 0)
      wStart += dilationW;

    // maxIndex remains in "channels-first"/contiguous
    int64_t maxIndex = tStart * iheight * iwidth + hStart * iwidth + wStart;

    if (!channels_last) {
      inputData += (int64_t) slice * itime * iheight * iwidth;
    } else {
      inputData += ((int64_t) batch * itime * iheight * iwidth * features) + channel;
    }

    scalar_t max = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity

    for (int t = tStart; t < tEnd; t += dilationT)
    {
      for (int h = hStart; h < hEnd; h += dilationH)
      {
        for (int w = wStart; w < wEnd; w += dilationW)
        {
          scalar_t val;
          int index = t * iheight * iwidth + h * iwidth + w;
          if (!channels_last) {
            val = inputData[index];
          } else {
            int64_t index_channels_last = static_cast<int64_t>(index) * features;
            val = inputData[index_channels_last];
          }

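          // NaN compares false with <, so without the explicit _isnan check a
          // NaN input would never be selected; with it, NaN wins the
          // comparison and propagates to the output.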
          if ((max < val) || at::_isnan(val))
          {
            max = val;
            maxIndex = index;
          }
        }
      }
    }

    int64_t out_index;
    if (!channels_last) {
      out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn;
    } else {
      out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel;
    }
    outputData[out_index] = max;
    indicesData[out_index] = maxIndex;
  }
}

template <typename scalar_t>
void max_pool3d_with_indices_out_frame(
    const scalar_t* input_data,
    const Tensor& output,
    const Tensor& indices,
    int features,
    int64_t totalZ,
    int itime, int iheight, int iwidth,
    int obatch, int otime, int oheight, int owidth,
    int kT, int kH, int kW,
    int dT, int dH, int dW,
    int pT, int pH, int pW,
    int dilationT, int dilationH, int dilationW,
    bool channels_last)
{
  int offsetZ = 0;
  int threadX = 32;
  int threadY = 8;
  int threadZ = 1;
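  // CUDA caps gridDim.z at 65535, so larger z extents are processed in chunks
  // of stepZ * threadZ planes, advancing offsetZ between kernel launches.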
  int stepZ = 65535;
  if (channels_last) {
    threadX = 2;
    threadY = 4;
    threadZ = 64;
  }
  dim3 block(threadX, threadY, threadZ);

  while (totalZ > 0) {
    dim3 grid(ceil_div(owidth, static_cast<int>(block.x)),
              ceil_div(oheight, static_cast<int>(block.y)),
              totalZ > stepZ*threadZ ? stepZ : ceil_div(totalZ, static_cast<int64_t>(threadZ)));

    max_pool3d_with_indices_single_out_frame
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        input_data,
        output.mutable_data_ptr<scalar_t>(),
        indices.mutable_data_ptr<int64_t>(),
        features,
        itime, iheight, iwidth,
        obatch, otime, oheight, owidth,
        kT, kH, kW,
        dT, dH, dW,
        pT, pH, pW,
        dilationT, dilationH, dilationW,
        offsetZ, channels_last);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    totalZ -= threadZ*stepZ;
    offsetZ += threadZ*stepZ;
  }
}
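// Backward pass: each thread reads one gradOutput element and scatter-adds it
// into gradInput at the max location recorded in indicesData. The atomic add
// is required because pooling windows may overlap, so several output elements
// can map to the same input element.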
template <typename scalar_t>
__global__ static void max_pool3d_with_indices_backward_single_out_frame(
    scalar_t *gradInputData,
    const scalar_t *gradOutputData,
    const int64_t *indicesData,
    int features,
    int itime, int iheight, int iwidth,
    int obatch, int otime, int oheight, int owidth,
    int offsetZ,
    bool channels_last)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow = blockIdx.y * blockDim.y + threadIdx.y;

  int oFrame = 0;
  // used only for channels-first indexing
  int64_t slice = 0;
  // used only for channels-last indexing
  int batch = 0;
  int channel = 0;
  if (!channels_last) {
    // indexing order: batch, channel, time
    oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature
  } else {
    // indexing order: batch, time, channel
    channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel)
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch * otime + time)
    batch = slice / otime;
    oFrame = slice % otime;
  }

  if (oRow < oheight && oColumn < owidth && oFrame < otime && batch < obatch && channel < features)
  {
    int64_t out_index;
    if (!channels_last) {
      out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn;
    } else {
      out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel;
    }
    int64_t maxIndex = indicesData[out_index];
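    // maxIndex is a flattened offset into a contiguous (t, h, w) slice
    // regardless of memory format (see the forward kernel), so the
    // channels-last branch below rescales it by `features`.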
    if (maxIndex != -1) {
      if (!channels_last) {
        gpuAtomicAddNoReturn(&gradInputData[(int64_t) slice * itime * iheight * iwidth + maxIndex],
                             gradOutputData[out_index]);
      } else {
        gpuAtomicAddNoReturn(&gradInputData[((int64_t) batch * itime * iheight * iwidth + maxIndex) * features + channel],
                             gradOutputData[out_index]);
      }
    }
  }
}

template <typename scalar_t>
void max_pool3d_with_indices_backward_out_frame(
    scalar_t *gradInputData,
    const Tensor& gradOutput,
    const Tensor& indices,
    int features,
    int64_t totalZ,
    int itime, int iheight, int iwidth,
    int obatch, int otime, int oheight, int owidth,
    bool channels_last)
{
  int offsetZ = 0;
  int threadX = 32;
  int threadY = 8;
  int threadZ = 1;
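  // As in the forward launcher, tile the z extent in chunks of
  // stepZ * threadZ planes to respect the 65535 gridDim.z limit.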
  int stepZ = 65535;
  if (channels_last) {
    threadX = 2;
    threadY = 4;
    threadZ = 64;
  }
  dim3 block(threadX, threadY, threadZ);

  while (totalZ > 0) {
    dim3 grid(ceil_div(owidth, static_cast<int>(block.x)),
              ceil_div(oheight, static_cast<int>(block.y)),
              totalZ > stepZ*threadZ ? stepZ : ceil_div(totalZ, static_cast<int64_t>(block.z)));

    max_pool3d_with_indices_backward_single_out_frame
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        gradInputData,
        gradOutput.const_data_ptr<scalar_t>(),
        indices.const_data_ptr<int64_t>(),
        features,
        itime, iheight, iwidth,
        obatch, otime, oheight, owidth,
        offsetZ,
        channels_last);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    totalZ -= threadZ*stepZ;
    offsetZ += threadZ*stepZ;
  }
}

void max_pool3d_with_indices_out_cuda_template(
    Tensor& output,
    Tensor& indices,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  TensorArg output_arg{ output, "output", 1 };
  TensorArg indices_arg{ indices, "indices", 2 };
  TensorArg input_arg{ input, "input", 3 };

  checkAllSameGPU(__func__,
                  {output_arg, indices_arg, input_arg});

  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, pT, dT, dilationT, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_out_cuda_template()");

  bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
  Tensor _input = input;
  if (input.ndimension() == 4) {
    Tensor input_channels_last_check = input.unsqueeze(0);
    // work around buggy behavior of suggest_memory_format here, where the
    // suggested format of the unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    channels_last = (!input_channels_last_check.is_contiguous()) &&
                    input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d);
    if (!channels_last) {
      output.resize_({nslices, otime, oheight, owidth});
      indices.resize_({nslices, otime, oheight, owidth});
    } else {
      _input = input_channels_last_check;
      output.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
      indices.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
      output = output.squeeze(0);
      indices = indices.squeeze(0);
    }
  } else {
    if (!channels_last) {
      output.resize_({nbatch, nslices, otime, oheight, owidth});
      indices.resize_({nbatch, nslices, otime, oheight, owidth});
    } else {
      output.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
      indices.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
    }
  }

  if (input.numel() == 0) {
    return;
  }

  Tensor work_input;
  Tensor work_output = output;
  if (!channels_last) {
    work_input = input.contiguous();
  } else {
    work_input = _input.contiguous(at::MemoryFormat::ChannelsLast3d);
  }
  Tensor work_indices = indices;

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
    input.scalar_type(),
    "max_pool3d_with_indices_out_frame",
    [&]{
      const scalar_t *input_data = work_input.const_data_ptr<scalar_t>();
      const int64_t totalZ = otime * nslices * nbatch;

      max_pool3d_with_indices_out_frame(
        input_data, work_output, work_indices,
        nslices, // features
        totalZ,
        itime, iheight, iwidth,
        nbatch, otime, oheight, owidth,
        kT, kH, kW,
        dT, dH, dW,
        pT, pH, pW,
        dilationT, dilationH, dilationW, channels_last);
    }
  );
}

void max_pool3d_with_indices_backward_out_cuda_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    const Tensor& indices,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
  TensorArg gradOutput_arg{ gradOutput, "gradOutput", 2 };
  TensorArg input_arg{ input, "input", 3 };
  TensorArg indices_arg{ indices, "indices", 4 };

  checkAllSameGPU(__func__,
                  {gradInput_arg, gradOutput_arg, input_arg, indices_arg});

  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
    "max_pool3d_with_indices_backward_out_cuda_template(): ",
    "Expected 4D or 5D input tensor, but got ", input.sizes());

  TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5),
    "max_pool3d_with_indices_backward_out_cuda_template(): ",
    "Expected 4D or 5D gradOutput tensor, but got ", gradOutput.sizes());

  // Resize and initialize result tensor.
  bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
  Tensor _input = input;
  if (input.ndimension() == 4) {
    Tensor input_channels_last_check = input.unsqueeze(0);
    // work around buggy behavior of suggest_memory_format here, where the
    // suggested format of the unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    channels_last = (!input_channels_last_check.is_contiguous()) &&
                    input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d);
    if (channels_last) {
      _input = input_channels_last_check;
    }
  }
  if (!channels_last) {
    gradInput.resize_as_(input);
  } else {
    gradInput.resize_as_(_input, at::MemoryFormat::ChannelsLast3d);
  }
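  // Zero-initialize gradInput: the backward kernel only scatter-adds at the
  // recorded max locations and never writes the remaining elements.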
  gradInput.zero_();

  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
  const int64_t nslices = input.size(-4);

  const int64_t otime = gradOutput.size(-3);
  const int64_t oheight = gradOutput.size(-2);
  const int64_t owidth = gradOutput.size(-1);

  const int64_t itime = gradInput.size(-3);
  const int64_t iheight = gradInput.size(-2);
  const int64_t iwidth = gradInput.size(-1);

  max_pool3d_backward_shape_check(
    input,
    gradOutput,
    indices,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_backward_out_cuda_template()");

  if (gradOutput.numel() == 0) {
    return;
  }

  Tensor work_grad_input = gradInput;
  Tensor work_grad_output;
  Tensor work_indices;
  if (!channels_last) {
    work_grad_output = gradOutput.contiguous();
    work_indices = indices.contiguous();
  } else {
    if (input.ndimension() == 4) {
      work_grad_output = gradOutput.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
      work_indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
    } else {
      work_grad_output = gradOutput.contiguous(at::MemoryFormat::ChannelsLast3d);
      work_indices = indices.contiguous(at::MemoryFormat::ChannelsLast3d);
    }
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool3d_with_indices_backward_out_frame",
    [&] {
      const int64_t totalZ = otime * nslices * nbatch;
      scalar_t *grad_input_data = work_grad_input.mutable_data_ptr<scalar_t>();

      max_pool3d_with_indices_backward_out_frame(
        grad_input_data, work_grad_output, work_indices,
        nslices,
        totalZ,
        itime, iheight, iwidth,
        nbatch, otime, oheight, owidth,
        channels_last);
    }
  );
}

} // namespace

std::tuple<Tensor&, Tensor&> max_pool3d_with_indices_out_cuda(const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& output,
    Tensor& indices)
{
  max_pool3d_with_indices_out_cuda_template(
    output,
    indices,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return std::tuple<Tensor&, Tensor&>(output, indices);
}

std::tuple<Tensor, Tensor> max_pool3d_with_indices_cuda(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  NoNamesGuard guard;

  Tensor output = at::empty({0}, input.options());
  Tensor indices = at::empty({0}, input.options().dtype(kLong));
  max_pool3d_with_indices_out_cuda_template(
    output,
    indices,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);

  guard.reset();
  namedinference::propagate_names(output, input);
  namedinference::propagate_names(indices, input);

  return std::tuple<Tensor, Tensor>(output, indices);
}

Tensor& max_pool3d_with_indices_backward_out_cuda(const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices,
    Tensor& gradInput)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("max_pool3d_with_indices_backward_out_cuda");
  max_pool3d_with_indices_backward_out_cuda_template(
    gradInput,
    gradOutput,
    input,
    indices,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return gradInput;
}

Tensor max_pool3d_with_indices_backward_cuda(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("max_pool3d_with_indices_backward_cuda");
  auto gradInput = at::empty(input.sizes(), input.options());
  max_pool3d_with_indices_backward_out_cuda_template(
    gradInput,
    gradOutput,
    input,
    indices,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return gradInput;
}

} // namespace at::native