// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
#endif

namespace at::native {
namespace {

__device__ inline int min(int a, int b) {
  return a <= b ? a : b;
}

template <typename scalar_t>
__global__ static void max_pool3d_with_indices_single_out_frame(
  const scalar_t* inputData,
  scalar_t* outputData,
  int64_t* indicesData,
  int features,
  int itime, int iheight, int iwidth,
  int obatch, int otime, int oheight, int owidth,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int pT, int pH, int pW,
  int dilationT, int dilationH, int dilationW,
  int offsetZ,
  bool channels_last)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow = blockIdx.y * blockDim.y + threadIdx.y;
  int oFrame = 0;
  // used only for channels-first indexing
  int64_t slice = 0;
  // used only for channels-last indexing
  int batch = 0;
  int channel = 0;
  if (!channels_last) {
    // indexing order: batch, channel, time
    oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature
  } else {
    // indexing order: batch, time, channel
    channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel)
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch + time)
    batch = slice / otime;
    oFrame = slice % otime;
  }

  // For int64_t data type, see https://github.com/pytorch/pytorch/issues/52822
  if (oRow < oheight && oColumn < owidth && oFrame < otime && channel < features && batch < obatch)
  {
    int tStart = oFrame  * dT - pT;
    int hStart = oRow    * dH - pH;
    int wStart = oColumn * dW - pW;
    int tEnd = min(tStart + (kT - 1) * dilationT + 1, itime);
    int hEnd = min(hStart + (kH - 1) * dilationH + 1, iheight);
    int wEnd = min(wStart + (kW - 1) * dilationW + 1, iwidth);

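    // Clamp the window start into the input while staying on the dilation
    // grid; stepping by the dilation keeps the remaining taps aligned with
    // those of the unclamped window.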
    while (tStart < 0)
      tStart += dilationT;
    while (hStart < 0)
      hStart += dilationH;
    while (wStart < 0)
      wStart += dilationW;

    // maxIndex remains in "channels-first"/contiguous
    int64_t maxIndex = tStart * iheight * iwidth + hStart * iwidth + wStart;

    if (!channels_last) {
        inputData += (int64_t) slice * itime * iheight * iwidth;
    } else {
        inputData += ((int64_t) batch * itime * iheight * iwidth * features) + channel;
    }

    scalar_t max = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity

    for (int t = tStart; t < tEnd; t += dilationT)
    {
      for (int h = hStart; h < hEnd; h += dilationH)
      {
        for (int w = wStart; w < wEnd; w += dilationW)
        {
          scalar_t val;
          int index = t * iheight * iwidth + h * iwidth + w;
          if (!channels_last) {
            val = inputData[index];
          } else {
            int64_t index_channels_last = index*features;
            val = inputData[index_channels_last];
          }

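          // NaNs propagate: a NaN value anywhere in the window wins the
          // comparison below, so the pooled output is NaN as well.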
          if ((max < val) || at::_isnan(val))
          {
            max = val;
            maxIndex = index;
          }
        }
      }
    }

    int64_t out_index;
    if (!channels_last) {
      out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn;
    } else {
      out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel;
    }
    outputData[out_index] = max;
    indicesData[out_index] = maxIndex;
  }
}

template <typename scalar_t>
void max_pool3d_with_indices_out_frame(
  const scalar_t* input_data,
  const Tensor& output,
  const Tensor& indices,
  int features,
  int64_t totalZ,
  int itime, int iheight, int iwidth,
  int obatch, int otime, int oheight, int owidth,
  int kT, int kH, int kW,
  int dT, int dH, int dW,
  int pT, int pH, int pW,
  int dilationT, int dilationH, int dilationW,
  bool channels_last)
{
  int offsetZ = 0;
  int threadX = 32;
  int threadY = 8;
  int threadZ = 1;
  int stepZ = 65535;
  if (channels_last) {
    threadX = 2;
    threadY = 4;
    threadZ = 64;
  }
  dim3 block(threadX, threadY, threadZ);

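  // gridDim.z is limited to 65535, so the launch loops: each pass covers up
  // to stepZ * threadZ slices in z, then offsetZ advances until every
  // (batch, feature, frame) plane has been processed.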
  while (totalZ > 0) {
    dim3 grid(ceil_div(owidth, static_cast<int>(block.x)),
              ceil_div(oheight, static_cast<int>(block.y)),
              totalZ > stepZ*threadZ ? stepZ : ceil_div(totalZ, static_cast<int64_t>(threadZ)));

    max_pool3d_with_indices_single_out_frame
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
         input_data,
         output.mutable_data_ptr<scalar_t>(),
         indices.mutable_data_ptr<int64_t>(),
         features,
         itime, iheight, iwidth,
         obatch, otime, oheight, owidth,
         kT, kH, kW,
         dT, dH, dW,
         pT, pH, pW,
         dilationT, dilationH, dilationW,
         offsetZ, channels_last);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    totalZ -= threadZ*stepZ;
    offsetZ += threadZ*stepZ;
  }
}

#undef UPDATE_OUTPUT_KERNEL_WIDTH

template <typename scalar_t>
__global__ static void max_pool3d_with_indices_backward_single_out_frame(
  scalar_t *gradInputData,
  const scalar_t *gradOutputData,
  const int64_t *indicesData,
  int features,
  int itime, int iheight, int iwidth,
  int obatch, int otime, int oheight, int owidth,
  int offsetZ,
  bool channels_last)
{
  int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int oRow = blockIdx.y * blockDim.y + threadIdx.y;

  int oFrame = 0;
  // used only for channels-first indexing
  int64_t slice = 0;
  // used only for channels-last indexing
  int batch = 0;
  int channel = 0;
  if (!channels_last) {
    // indexing order: batch, channel, time
    oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature
  } else {
    // indexing order: batch, time, channel
    channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel)
    slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch + time)
    batch = slice / otime;
    oFrame = slice % otime;
  }

  if (oRow < oheight && oColumn < owidth && oFrame < otime && batch < obatch && channel < features)
  {
    int64_t out_index;
    if (!channels_last) {
      out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn;
    } else {
      out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel;
    }
    int64_t maxIndex = indicesData[out_index];
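    // Overlapping pooling windows may route gradients to the same input
    // element, so the scatter below must use atomic adds; this is also why
    // the backward pass is nondeterministic.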
    if (maxIndex != -1) {
      if (!channels_last) {
        gpuAtomicAddNoReturn(&gradInputData[(int64_t) slice * itime  * iheight * iwidth + maxIndex],
          gradOutputData[out_index]);
      } else {
        gpuAtomicAddNoReturn(&gradInputData[((int64_t) batch * itime * iheight * iwidth + maxIndex) * features + channel],
          gradOutputData[out_index]);
      }
    }
  }
}

template <typename scalar_t>
void max_pool3d_with_indices_backward_out_frame(
  scalar_t *gradInputData,
  const Tensor& gradOutput,
  const Tensor& indices,
  int features,
  int64_t totalZ,
  int itime, int iheight, int iwidth,
  int obatch, int otime, int oheight, int owidth,
  bool channels_last)
{
  int offsetZ = 0;
  int threadX = 32;
  int threadY = 8;
  int threadZ = 1;
  int stepZ = 65535;
  if (channels_last) {
    threadX = 2;
    threadY = 4;
    threadZ = 64;
  }
  dim3 block(threadX, threadY, threadZ);

  while (totalZ > 0) {
    dim3 grid(ceil_div(owidth, static_cast<int>(block.x)),
              ceil_div(oheight, static_cast<int>(block.y)),
              totalZ > stepZ*threadZ ? stepZ : ceil_div(totalZ, static_cast<int64_t>(block.z)));

    max_pool3d_with_indices_backward_single_out_frame
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        gradInputData,
        gradOutput.const_data_ptr<scalar_t>(),
        indices.const_data_ptr<int64_t>(),
        features,
        itime, iheight, iwidth,
        obatch, otime, oheight, owidth,
        offsetZ,
        channels_last);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    totalZ -= threadZ*stepZ;
    offsetZ += threadZ*stepZ;
  }
}

void max_pool3d_with_indices_out_cuda_template(
           Tensor& output,
           Tensor& indices,
           const Tensor& input,
           IntArrayRef kernel_size,
           IntArrayRef stride,
           IntArrayRef padding,
           IntArrayRef dilation,
           bool ceil_mode)
{
  TensorArg output_arg{ output, "output", 1 };
  TensorArg indices_arg{ indices, "indices", 2 };
  TensorArg input_arg{ input, "input", 3 };

  checkAllSameGPU(__func__,
                  {output_arg, indices_arg, input_arg});

  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

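  // pooling_output_shape computes the standard dilated-pooling output size,
  //   out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1,
  // rounding up rather than down when ceil_mode is set (the last window is
  // still required to start inside the input or the left padding).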
  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, pT, dT, dilationT, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_out_cuda_template()");

  bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
  Tensor _input = input;
  if (input.ndimension() == 4) {
    Tensor input_channels_last_check = input.unsqueeze(0);
    // work around buggy behavior of suggest_memory_format here where
    // suggested format of unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    channels_last = (!input_channels_last_check.is_contiguous()) &&
                     input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d);
    if (!channels_last) {
      output.resize_({nslices, otime, oheight, owidth});
      indices.resize_({nslices, otime, oheight, owidth});
    } else {
      _input = input_channels_last_check;
      output.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
      indices.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
      output = output.squeeze(0);
      indices = indices.squeeze(0);
    }
  } else {
    if (!channels_last) {
      output.resize_({nbatch, nslices, otime, oheight, owidth});
      indices.resize_({nbatch, nslices, otime, oheight, owidth});
    } else {
      output.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
      indices.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d);
    }
  }

  if (input.numel() == 0) {
    return;
  }

  Tensor work_input;
  Tensor work_output = output;
  if (!channels_last) {
    work_input = input.contiguous();
  } else {
    work_input = _input.contiguous(at::MemoryFormat::ChannelsLast3d);
  }
  Tensor work_indices = indices;

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
    input.scalar_type(),
    "max_pool3d_with_indices_out_frame",
    [&]{
      const scalar_t *input_data = work_input.const_data_ptr<scalar_t>();
      const int64_t totalZ = otime * nslices * nbatch;

      max_pool3d_with_indices_out_frame(
        input_data, work_output, work_indices,
        nslices, // features
        totalZ,
        itime, iheight, iwidth,
        nbatch, otime, oheight, owidth,
        kT, kH, kW,
        dT, dH, dW,
        pT, pH, pW,
        dilationT, dilationH, dilationW, channels_last);
    }
  );
}

void max_pool3d_with_indices_backward_out_cuda_template(
           Tensor& gradInput,
           const Tensor& gradOutput,
           const Tensor& input,
           const Tensor& indices,
           IntArrayRef kernel_size,
           IntArrayRef stride,
           IntArrayRef padding,
           IntArrayRef dilation,
           bool ceil_mode)
{
  TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
  TensorArg gradOutput_arg{ gradOutput, "gradOutput", 2 };
  TensorArg input_arg{ input, "input", 3 };
  TensorArg indices_arg{ indices, "indices", 4 };

  checkAllSameGPU(__func__,
                  {gradInput_arg, gradOutput_arg, input_arg, indices_arg});

  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
    "max_pool3d_with_indices_backward_out_cuda_template(): ",
    "Expected 4D or 5D input tensor, but got ", input.sizes());

  TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5),
    "max_pool3d_with_indices_backward_out_cuda_template(): ",
    "Expected 4D or 5D gradOutput tensor, but got ", gradOutput.sizes());

  // Resize and initialize result tensor.
  bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
  Tensor _input = input;
  if (input.ndimension() == 4) {
    Tensor input_channels_last_check = input.unsqueeze(0);
    // work around buggy behavior of suggest_memory_format here where
    // suggested format of unsqueezed tensor is contiguous while it is
    // really only contiguous in ChannelsLast3d
    channels_last = (!input_channels_last_check.is_contiguous()) &&
                     input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d);
    if (channels_last) {
      _input = input_channels_last_check;
    }
  }
  if (!channels_last) {
    gradInput.resize_as_(input);
  } else {
    gradInput.resize_as_(_input, at::MemoryFormat::ChannelsLast3d);
  }
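  // The backward kernel accumulates into gradInput with atomic adds, so the
  // buffer must start out zeroed.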
  gradInput.zero_();

  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
  const int64_t nslices = input.size(-4);

  const int64_t otime = gradOutput.size(-3);
  const int64_t oheight = gradOutput.size(-2);
  const int64_t owidth = gradOutput.size(-1);

  const int64_t itime = gradInput.size(-3);
  const int64_t iheight = gradInput.size(-2);
  const int64_t iwidth = gradInput.size(-1);

  max_pool3d_backward_shape_check(
    input,
    gradOutput,
    indices,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_backward_out_cuda_template()");

  if (gradOutput.numel() == 0) {
    return;
  }

  Tensor work_grad_input = gradInput;
  Tensor work_grad_output;
  Tensor work_indices;
  if (!channels_last) {
    work_grad_output = gradOutput.contiguous();
    work_indices = indices.contiguous();
  } else {
    if (input.ndimension() == 4) {
      work_grad_output = gradOutput.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
      work_indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
    } else {
      work_grad_output = gradOutput.contiguous(at::MemoryFormat::ChannelsLast3d);
      work_indices = indices.contiguous(at::MemoryFormat::ChannelsLast3d);
    }
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool3d_with_indices_backward_out_frame",
    [&] {
      const int64_t totalZ = otime * nslices * nbatch;
      scalar_t *grad_input_data = work_grad_input.mutable_data_ptr<scalar_t>();

      max_pool3d_with_indices_backward_out_frame(
        grad_input_data, work_grad_output, work_indices,
        nslices,
        totalZ,
        itime, iheight, iwidth,
        nbatch, otime, oheight, owidth,
        channels_last);
    }
  );
}

} // namespace

std::tuple<Tensor&, Tensor&> max_pool3d_with_indices_out_cuda(const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  Tensor& output,
  Tensor& indices)
{
  max_pool3d_with_indices_out_cuda_template(
    output,
    indices,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return std::tuple<Tensor&, Tensor&>(output, indices);
}

std::tuple<Tensor, Tensor> max_pool3d_with_indices_cuda(
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode)
{
  NoNamesGuard guard;

  Tensor output = at::empty({0}, input.options());
  Tensor indices = at::empty({0}, input.options().dtype(kLong));
  max_pool3d_with_indices_out_cuda_template(
    output,
    indices,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);

  guard.reset();
  namedinference::propagate_names(output, input);
  namedinference::propagate_names(indices, input);

  return std::tuple<Tensor, Tensor>(output, indices);
}

Tensor& max_pool3d_with_indices_backward_out_cuda(const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  const Tensor& indices,
  Tensor& gradInput)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("max_pool3d_with_indices_backward_out_cuda");
  max_pool3d_with_indices_backward_out_cuda_template(
    gradInput,
    gradOutput,
    input,
    indices,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return gradInput;
}

Tensor max_pool3d_with_indices_backward_cuda(
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  const Tensor& indices)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("max_pool3d_with_indices_backward_cuda");
  auto gradInput = at::empty(input.sizes(), input.options());
  max_pool3d_with_indices_backward_out_cuda_template(
    gradInput,
    gradOutput,
    input,
    indices,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode);
  return gradInput;
}

} // namespace at::native
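
// Usage sketch: these kernels are reached through the regular ATen
// dispatcher, e.g. at::max_pool3d_with_indices on a CUDA tensor (the shapes
// below are arbitrary):
//
//   #include <ATen/ATen.h>
//
//   at::Tensor x = at::randn({2, 3, 8, 8, 8},
//                            at::device(at::kCUDA).dtype(at::kFloat));
//   auto [out, idx] = at::max_pool3d_with_indices(
//       x,
//       /*kernel_size=*/{2, 2, 2}, /*stride=*/{2, 2, 2},
//       /*padding=*/{0, 0, 0}, /*dilation=*/{1, 1, 1},
//       /*ceil_mode=*/false);
//   // out: [2, 3, 4, 4, 4]; idx holds flat offsets into each input's
//   // itime*iheight*iwidth slice (channels-first layout, see maxIndex above).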