// aten/src/ATen/native/cuda/MaxUnpooling.cu
// (from /aosp_15_r20/external/pytorch, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/max_unpool2d_native.h>
#include <ATen/ops/max_unpool3d_native.h>

#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif

namespace at::native {

using namespace at::cuda::detail;

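// Integer division rounded up; used below to size kernel grids.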
template <typename T>
__host__ __device__ __forceinline__ T ceilDiv(T a, T b) {
  return (a + b - 1) / b;
}

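// Scatters each input element to output[maxind] within its (n, c) plane.
// With duplicate indices the result is nondeterministic (writes land in no
// defined order), hence the alertNotDeterministic call in the host code.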
template <typename T>
__global__ void max_unpooling2d_forward_kernel(
    const int64_t numInputElements,
    const T* input,
    const int64_t* indices,
    const int64_t numChannels,
    const int64_t inputHeight,
    const int64_t inputWidth,
    const int64_t outputHeight,
    const int64_t outputWidth,
    T* output) {
  int64_t outputImageSize = outputHeight * outputWidth;
  CUDA_KERNEL_LOOP(linearIndex, numInputElements) {
    int c = (linearIndex / inputWidth / inputHeight) % numChannels;
    int n = linearIndex / inputWidth / inputHeight / numChannels;
    // Use a local pointer: CUDA_KERNEL_LOOP is a grid-stride loop, so
    // mutating `output` itself would compound the offset across iterations.
    T* out_plane = output + (n * numChannels + c) * outputHeight * outputWidth;
    int64_t maxind = indices[linearIndex];
    CUDA_KERNEL_ASSERT(maxind >= 0 && maxind < outputImageSize);
    out_plane[maxind] = input[linearIndex];
  }
}

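// 3D variant: one thread per (row, column) position of one input frame.
// blockIdx.z (plus offsetZ) enumerates frame/slice pairs; the host code
// relaunches when that count exceeds the 65535 limit on gridDim.z.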
template <typename T>
__global__ void max_unpooling3d_forward_kernel(
    PackedTensorAccessor64<const T, 4> input,
    PackedTensorAccessor64<const int64_t, 4> indices,
    T* output,
    const int64_t oT,
    const int64_t oH,
    const int64_t oW,
    const int64_t offsetZ) {
  int64_t iColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int64_t iRow = blockIdx.y * blockDim.y + threadIdx.y;
  int64_t iFrame = (blockIdx.z + offsetZ) % input.size(1); // input frame/time
  int64_t slice = (blockIdx.z + offsetZ) / input.size(1); // input slice/feature
  int64_t outputImageSize = oT * oH * oW;
  if (iRow < input.size(2) && iColumn < input.size(3)) {
    const T val = input[slice][iFrame][iRow][iColumn];
    const int64_t index = indices[slice][iFrame][iRow][iColumn];
    CUDA_KERNEL_ASSERT(index >= 0 && index < outputImageSize);
    output[slice * oT * oH * oW + index] = val;
  }
}

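// Backward pass is a gather: each grad_input element reads grad_output at
// the location its maximum was scattered to in the forward pass.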
template <typename T>
__global__ void max_unpooling2d_backward_kernel(
    const int64_t numInputElements,
    const T* input,
    const int64_t* indices,
    const int64_t numChannels,
    const int64_t inputHeight,
    const int64_t inputWidth,
    const int64_t outputHeight,
    const int64_t outputWidth,
    T* output) {
  CUDA_KERNEL_LOOP(linearIndex, numInputElements) {
    int c = (linearIndex / inputWidth / inputHeight) % numChannels;
    int n = linearIndex / inputWidth / inputHeight / numChannels;
    // As in the forward kernel, keep the base pointer unmodified so each
    // grid-stride iteration computes its offset from scratch.
    const T* in_plane = input + (n * numChannels + c) * outputHeight * outputWidth;
    int64_t maxind = indices[linearIndex];
    output[linearIndex] = in_plane[maxind];
  }
}

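// 3D backward: gathers gradOutputData[index] back into gradInput, using the
// same blockIdx.z/offsetZ slicing scheme as the forward kernel.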
template <typename T>
__global__ void max_unpooling3d_backward_kernel(
    const T* gradOutputData,
    int64_t oT,
    int64_t oH,
    int64_t oW,
    PackedTensorAccessor64<int64_t, 4> indices,
    PackedTensorAccessor64<T, 4> gradInput,
    int offsetZ) {
  int iColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int iRow = blockIdx.y * blockDim.y + threadIdx.y;
  int iFrame = (blockIdx.z + offsetZ) % gradInput.size(1); // output frame/time
  int slice =
      (blockIdx.z + offsetZ) / gradInput.size(1); // output slice/feature

  if (iRow < gradInput.size(2) && iColumn < gradInput.size(3)) {
    int64_t index = indices[slice][iFrame][iRow][iColumn];
    T grad_val = gradOutputData[slice * oT * oH * oW + index];
    gradInput[slice][iFrame][iRow][iColumn] = grad_val;
  }
}

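// Validates the arguments, resizes `output`, zero-fills it, and launches
// max_unpooling2d_forward_kernel over every input element.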
Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
  at::globalContext().alertNotDeterministic("max_unpooling2d_forward_out");

  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got: ", indices_.scalar_type());
  auto oheight = output_size[0];
  auto owidth = output_size[1];

  TensorArg output_arg{output, "output", 1}, self_arg{self_, "self_", 2},
      indices_arg{indices_, "indices_", 3};
  checkAllSameGPU(
      "max_unpooling2d_forward_out_cuda", {output_arg, self_arg, indices_arg});

  for (int64_t i = 1; i < self_.ndimension(); ++i) {
    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cuda(): ",
                "Expected input to have non-zero size for non-batch dimensions, but got ",
                self_.sizes(), " with dimension ", i, " being empty.");
  }

  TORCH_CHECK(
      (self_.ndimension() == 3 || self_.ndimension() == 4),
      "Input to max_unpooling2d should be a 3d or 4d Tensor, but got tensor with dimension: ", self_.ndimension());
  TORCH_CHECK(
      self_.sizes() == indices_.sizes(),
      "Expected shape of indices to be: ", self_.sizes(), " but got: ", indices_.sizes());
  TORCH_CHECK(
      output_size.size() == 2,
      "There should be exactly two elements (height, width) in output_size, but got ", output_size.size(), " elements.");

  int64_t dimw = 2;
  int64_t dimh = 1;
  int64_t numBatch = 1;

  int64_t numChannels;
  int64_t inputHeight;
  int64_t inputWidth;

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();

  if (self.ndimension() == 4) {
    numBatch = self.size(0);
    dimw++;
    dimh++;
  }
  numChannels = self.size(dimh - 1);
  inputHeight = self.size(dimh);
  inputWidth = self.size(dimw);

  output.resize_({numBatch, numChannels, oheight, owidth});

  output.zero_();

  auto count = self.numel();
  if (count != 0) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
        self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
          max_unpooling2d_forward_kernel<<<
              GET_BLOCKS(count),
              CUDA_NUM_THREADS,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              self.numel(),
              self.const_data_ptr<scalar_t>(),
              indices.const_data_ptr<int64_t>(),
              numChannels,
              inputHeight,
              inputWidth,
              oheight,
              owidth,
              output.mutable_data_ptr<scalar_t>());
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }));
  }
  if (self.ndimension() == 3) {
    output.resize_({numChannels, oheight, owidth});
  }
  return output;
}

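// Functional wrapper around the out-variant. A minimal usage sketch via the
// public ATen API (illustrative shapes only):
//   auto [pooled, idx] = at::max_pool2d_with_indices(x, /*kernel_size=*/{2, 2});
//   auto restored = at::max_unpool2d(pooled, idx, {x.size(2), x.size(3)});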
Tensor max_unpooling2d_forward_cuda(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling2d_forward_out_cuda(self, indices, output_size, output);
  return output;
}

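// Shared argument validation for the 3D forward and backward paths; pass an
// undefined Tensor as gradOutput to skip the gradient-shape checks.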
static void max_unpooling3d_shape_check(
    const Tensor& input,
    const Tensor& gradOutput,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    const char* fn_name) {
  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];
  TORCH_CHECK(
      indices.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got: ", indices.scalar_type());
  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with dim ", input.ndimension());
  TORCH_CHECK(
      output_size.size() == 3,
      "There should be exactly three elements (depth, height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      stride.size() == 3,
      "There should be exactly three elements (depth, height, width) in stride, but got: ", stride.size(), " elements.");
  TORCH_CHECK(
      padding.size() == 3,
      "There should be exactly three elements (depth, height, width) in padding, but got: ", padding.size(), " elements.");
  TORCH_CHECK(
      input.sizes() == indices.sizes(),
      "Expected shape of indices to be: ", input.sizes(), " but got: ", indices.sizes());

  for (int64_t i = 1; i < input.ndimension(); ++i) {
    TORCH_CHECK(input.size(i) > 0, fn_name,
                ": Expected input to have non-zero size for non-batch dimensions, but got ",
                input.sizes(), " with dimension ", i, " being empty.");
  }

  TORCH_CHECK(
      stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
      "strides should be greater than zero, but got stride: ",
      stride);

  int dimw = 3;
  int dimh = 2;
  int dimt = 1;
  int dimn = 0;

  if (input.ndimension() == 5) {
    dimw++;
    dimh++;
    dimt++;
    dimn++;
  }

  int nslices = input.size(dimn);

  if (gradOutput.defined()) {
    if (oT != gradOutput.size(dimt) || oH != gradOutput.size(dimh) ||
        oW != gradOutput.size(dimw)) {
      AT_ERROR(
          "Inconsistent gradOutput size. oT= ",
          oT,
          ", oH= ",
          oH,
          ", oW= ",
          oW,
          ". gradOutput: ",
          gradOutput.size(dimt),
          "x",
          gradOutput.size(dimh),
          "x",
          gradOutput.size(dimw));
    }
    TORCH_CHECK(
        gradOutput.ndimension() == input.ndimension() &&
            gradOutput.size(dimn) == nslices,
        "gradOutput and input Tensors should have same number of dimensions and also the same number of channels/slices");
  }
}

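// As in the 2D case, but the launch is tiled: a 2D thread block covers the
// spatial plane and grid.z walks the (batch * slice * frame) dimension.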
Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
  at::globalContext().alertNotDeterministic("max_unpooling3d_forward_out");

  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
  max_unpooling3d_shape_check(
    self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cuda()");

  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

  TensorArg output_arg{output, "output", 1}, self_arg{self_, "self_", 2},
      indices_arg{indices_, "indices_", 3};
  checkAllSameGPU(
      "max_unpooling3d_forward_out_cuda", {output_arg, self_arg, indices_arg});

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();

  int64_t batchSize;
  int64_t inputSlices;
  int64_t inputTime;
  int64_t inputHeight;
  int64_t inputWidth;

  if (self.ndimension() == 4) {
    batchSize = 1;
    inputSlices = self.size(0);
    inputTime = self.size(1);
    inputHeight = self.size(2);
    inputWidth = self.size(3);
    output.resize_({inputSlices, oT, oH, oW});
  } else {
    batchSize = self.size(0);
    inputSlices = self.size(1);
    inputTime = self.size(2);
    inputHeight = self.size(3);
    inputWidth = self.size(4);
    output.resize_({batchSize, inputSlices, oT, oH, oW});
  }

  output.zero_();

  // Collapse batch and feature dimensions if needed
  if (self.ndimension() == 5) {
    self = self.reshape({self.size(0) * self.size(1),
                         self.size(2),
                         self.size(3),
                         self.size(4)});
    indices = indices.reshape({indices.size(0) * indices.size(1),
                               indices.size(2),
                               indices.size(3),
                               indices.size(4)});
  }

  if (self.numel() == 0) {
    return output;
  }

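  // gridDim.z is capped at 65535, so launch repeatedly, advancing offsetZ,
  // until every (batch, slice, frame) plane has been covered.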
  int totalZ = inputTime * inputSlices * batchSize;
  int offsetZ = 0;
  dim3 block(32, 8);

  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
      self.scalar_type(), "max_unpooling3d_forward_kernel", ([&] {
        while (totalZ > 0) {
          dim3 grid(
              ceilDiv(inputWidth, static_cast<int64_t>(block.x)),
              ceilDiv(inputHeight, static_cast<int64_t>(block.y)),
              totalZ > 65535 ? 65535 : totalZ);
          max_unpooling3d_forward_kernel<<<
              grid,
              block,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              self.packed_accessor64<const scalar_t, 4>(),
              indices.packed_accessor64<const int64_t, 4>(),
              output.mutable_data_ptr<scalar_t>(),
              oT,
              oH,
              oW,
              offsetZ);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          totalZ -= 65535;
          offsetZ += 65535;
        }
      }));
  return output;
}

Tensor max_unpooling3d_forward_cuda(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling3d_forward_out_cuda(
      self, indices, output_size, stride, padding, output);
  return output;
}

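// Validates shapes, resizes grad_input to match the forward input, and
// gathers gradients with max_unpooling2d_backward_kernel.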
at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_,
    const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    Tensor& grad_input) {
  int64_t oheight = output_size[0];
  int64_t owidth = output_size[1];
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got type: ", indices_.scalar_type());
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2},
      self_arg{self_, "self_", 3}, indices_arg{indices_, "indices_", 4};
  checkAllSameGPU(
      "max_unpooling2d_backward_out_cuda",
      {grad_input_arg, grad_output_arg, self_arg, indices_arg});

  TORCH_CHECK(
      (self_.ndimension() == 3 || self_.ndimension() == 4),
      "Input to max_unpooling2d should be a 3d or 4d Tensor, instead got: ",
      self_);

  TORCH_CHECK(
      self_.sizes() == indices_.sizes(),
      "Expected shape of indices to be: ", self_.sizes(), " but got: ", indices_.sizes());

  TORCH_CHECK(output_size.size() == 2, "output_size must have two elements, got size: ", output_size.size());

  int64_t nInputCols, nInputRows, nInputPlane;

  int dimw = 2;
  int dimh = 1;

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();
  auto grad_output = grad_output_.contiguous();

  if (self.ndimension() == 3) {
    nInputPlane = self.size(0);
  } else {
    ++dimw;
    ++dimh;
    nInputPlane = self.size(1);
  }

  nInputCols = self.size(dimw);
  nInputRows = self.size(dimh);

  if (oheight != grad_output.size(dimh) || owidth != grad_output.size(dimw)) {
    AT_ERROR(
        "Inconsistent gradOutput size. output height: ",
        oheight,
        ", output width: ",
        owidth,
        ", gradOutput: ",
        grad_output.size(dimh),
        "x",
        grad_output.size(dimw));
  }

  grad_input.resize_as_(self);
  grad_input.zero_();

  int64_t count = self.numel();
  if (count == 0) {
    return grad_input;
  }

  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
      self.scalar_type(), "max_unpooling2d_backward_kernel", ([&] {
        max_unpooling2d_backward_kernel<<<
            GET_BLOCKS(count),
            CUDA_NUM_THREADS,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            count,
            grad_output.const_data_ptr<scalar_t>(),
            indices.const_data_ptr<int64_t>(),
            nInputPlane,
            nInputRows,
            nInputCols,
            oheight,
            owidth,
            grad_input.mutable_data_ptr<scalar_t>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }));
  return grad_input;
}

at::Tensor max_unpooling2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  at::native::max_unpooling2d_backward_out_cuda(
      grad_output, self, indices, output_size, grad_input);
  return grad_input;
}

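// 3D gradient: same tiling as the forward pass, but reading from grad_output
// and writing into grad_input at the saved indices.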
at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
    const Tensor& self_,
    const Tensor& indices_,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& grad_input) {
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

  max_unpooling3d_shape_check(
    self_, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cuda()");

  int batchSize = 0;
  int inputSlices = 0;
  int inputTime = 0;
  int64_t inputHeight = 0;
  int64_t inputWidth = 0;

  TensorArg self_arg{self_, "self_", 1}, indices_arg{indices_, "indices_", 2},
      grad_output_arg{grad_output_, "grad_output_", 3},
      grad_input_arg{grad_input, "grad_input", 4};
  checkAllSameGPU(
      "max_unpooling3d_backward_out_cuda",
      {self_arg, indices_arg, grad_output_arg, grad_input_arg});

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();
  auto grad_output = grad_output_.contiguous();

  if (self.ndimension() == 4) {
    batchSize = 1;
    inputSlices = self.size(0);
    inputTime = self.size(1);
    inputHeight = self.size(2);
    inputWidth = self.size(3);
  } else {
    batchSize = self.size(0);
    inputSlices = self.size(1);
    inputTime = self.size(2);
    inputHeight = self.size(3);
    inputWidth = self.size(4);
  }

  grad_input.resize_as_(self);
  grad_input.zero_();

  // Collapse batch and feature dimensions if needed
  auto grad_input_reshaped = grad_input;
  if (grad_input.ndimension() == 5) {
    grad_input_reshaped =
        grad_input.reshape({grad_input.size(0) * grad_input.size(1),
                            grad_input.size(2),
                            grad_input.size(3),
                            grad_input.size(4)});

    indices = indices.reshape({indices.size(0) * indices.size(1),
                               indices.size(2),
                               indices.size(3),
                               indices.size(4)});
  }
  if (grad_input.numel() == 0) {
    return grad_input;
  }

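  // Chunk over grid.z exactly as in the forward pass (gridDim.z <= 65535).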
  int totalZ = inputTime * inputSlices * batchSize;
  int offsetZ = 0;

  dim3 block(32, 8);

  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
      self.scalar_type(), "max_unpooling3d_backward_kernel", ([&] {
        while (totalZ > 0) {
          dim3 grid(
              ceilDiv(inputWidth, static_cast<int64_t>(block.x)),
              ceilDiv(inputHeight, static_cast<int64_t>(block.y)),
              totalZ > 65535 ? 65535 : totalZ);
          max_unpooling3d_backward_kernel<<<
              grid,
              block,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              grad_output.const_data_ptr<scalar_t>(),
              oT,
              oH,
              oW,
              indices.packed_accessor64<int64_t, 4>(),
              grad_input_reshaped.packed_accessor64<scalar_t, 4>(),
              offsetZ);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          totalZ -= 65535;
          offsetZ += 65535;
        }
      }));
  return grad_input;
}

at::Tensor max_unpooling3d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  at::native::max_unpooling3d_backward_out_cuda(
      grad_output, self, indices, output_size, stride, padding, grad_input);
  return grad_input;
}

} // namespace at::native