#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/max_unpool2d_native.h>
#include <ATen/ops/max_unpool3d_native.h>

#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif

namespace at::native {

using namespace at::cuda::detail;

template <typename T>
__host__ __device__ __forceinline__ T ceilDiv(T a, T b) {
  return (a + b - 1) / b;
}

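// Forward scatter kernel for max_unpool2d: each iteration of the grid-stride
// loop reads one element of the contiguous NCHW input, looks up the flat
// index (within the HxW output plane) recorded by max_pool2d, and writes the
// value there. Out-of-range indices trigger a device-side assert; duplicate
// indices make the result nondeterministic.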
template <typename T>
__global__ void max_unpooling2d_forward_kernel(
    const int64_t numInputElements,
    const T* input,
    const int64_t* indices,
    const int64_t numChannels,
    const int64_t inputHeight,
    const int64_t inputWidth,
    const int64_t outputHeight,
    const int64_t outputWidth,
    T* output) {
  int64_t outputImageSize = outputHeight * outputWidth;
  CUDA_KERNEL_LOOP(linearIndex, numInputElements) {
    int c = (linearIndex / inputWidth / inputHeight) % numChannels;
    int n = linearIndex / inputWidth / inputHeight / numChannels;
    output += (n * numChannels + c) * outputHeight * outputWidth;
    int maxind = indices[linearIndex];
    CUDA_KERNEL_ASSERT(maxind >= 0 && maxind < outputImageSize);
    output[maxind] = input[linearIndex];
  }
}

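// Forward scatter kernel for max_unpool3d: the input is viewed as a 4-D
// accessor of shape (batch * slices, time, height, width). Each thread
// handles one (row, column) position of one frame; blockIdx.z enumerates the
// frames, with offsetZ compensating for launches that are split to respect
// the 65535 limit on gridDim.z. The stored index is a flat offset into the
// oT*oH*oW output volume of the current slice.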
template <typename T>
__global__ void max_unpooling3d_forward_kernel(
    PackedTensorAccessor64<const T, 4> input,
    PackedTensorAccessor64<const int64_t, 4> indices,
    T* output,
    const int64_t oT,
    const int64_t oH,
    const int64_t oW,
    const int64_t offsetZ) {
  int64_t iColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int64_t iRow = blockIdx.y * blockDim.y + threadIdx.y;
  int64_t iFrame = (blockIdx.z + offsetZ) % input.size(1); // input frame/time
  int64_t slice = (blockIdx.z + offsetZ) / input.size(1); // input slice/feature
  int64_t outputImageSize = oT * oH * oW;
  if (iRow < input.size(2) && iColumn < input.size(3)) {
    const T val = input[slice][iFrame][iRow][iColumn];
    const int64_t index = indices[slice][iFrame][iRow][iColumn];
    CUDA_KERNEL_ASSERT(index >= 0 && index < outputImageSize);
    output[slice * oT * oH * oW + index] = val;
  }
}

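// Backward kernel for max_unpool2d: the gather that mirrors the forward
// scatter. Each thread reads the flat index recorded for one pooled position
// and copies the corresponding grad_output value into grad_input. Here
// `input` holds grad_output and `output` holds grad_input, matching the
// forward kernel's parameter names.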
template <typename T>
__global__ void max_unpooling2d_backward_kernel(
    const int64_t numInputElements,
    const T* input,
    const int64_t* indices,
    const int64_t numChannels,
    const int64_t inputHeight,
    const int64_t inputWidth,
    const int64_t outputHeight,
    const int64_t outputWidth,
    T* output) {
  CUDA_KERNEL_LOOP(linearIndex, numInputElements) {
    int c = (linearIndex / inputWidth / inputHeight) % numChannels;
    int n = linearIndex / inputWidth / inputHeight / numChannels;
    input += (n * numChannels + c) * outputHeight * outputWidth;
    int maxind = indices[linearIndex];
    output[linearIndex] = input[maxind];
  }
}

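// Backward kernel for max_unpool3d: gathers gradOutputData[index] back into
// gradInput at the pooled position, using the same (slice, frame, row,
// column) decomposition of blockIdx.z + offsetZ as the forward kernel.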
template <typename T>
__global__ void max_unpooling3d_backward_kernel(
    const T* gradOutputData,
    int64_t oT,
    int64_t oH,
    int64_t oW,
    PackedTensorAccessor64<int64_t, 4> indices,
    PackedTensorAccessor64<T, 4> gradInput,
    int offsetZ) {
  int iColumn = blockIdx.x * blockDim.x + threadIdx.x;
  int iRow = blockIdx.y * blockDim.y + threadIdx.y;
  int iFrame = (blockIdx.z + offsetZ) % gradInput.size(1); // output frame/time
  int slice =
      (blockIdx.z + offsetZ) / gradInput.size(1); // output slice/feature

  if (iRow < gradInput.size(2) && iColumn < gradInput.size(3)) {
    int64_t index = indices[slice][iFrame][iRow][iColumn];
    T grad_val = gradOutputData[slice * oT * oH * oW + index];
    gradInput[slice][iFrame][iRow][iColumn] = grad_val;
  }
}

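// Host entry point for max_unpool2d forward: validates dtypes and shapes,
// resizes and zero-fills the output, then launches the scatter kernel with a
// 1-D grid covering every input element (GET_BLOCKS / CUDA_NUM_THREADS).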
Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
                                         const Tensor& indices_,
                                         IntArrayRef output_size,
                                         Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
  at::globalContext().alertNotDeterministic("max_unpooling2d_forward_out");

  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got: ", indices_.scalar_type());
  auto oheight = output_size[0];
  auto owidth = output_size[1];

  TensorArg output_arg{output, "output", 1}, self_arg{self_, "self_", 2},
      indices_arg{indices_, "indices_", 3};
  checkAllSameGPU(
      "max_unpooling2d_forward_out_cuda", {output_arg, self_arg, indices_arg});

  for (int64_t i = 1; i < self_.ndimension(); ++i) {
    TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cuda(): ",
                "Expected input to have non-zero size for non-batch dimensions, but got ",
                self_.sizes(), " with dimension ", i , " being empty.");
  }

  TORCH_CHECK(
      (self_.ndimension() == 3 || self_.ndimension() == 4),
      "Input to max_unpooling2d should be a 3d or 4d Tensor, but got tensor with dimension: ", self_.ndimension());
  TORCH_CHECK(
      self_.sizes() == indices_.sizes(),
      "Expected shape of indices to be: ", self_.sizes(), " but got: ", indices_.sizes());
  TORCH_CHECK(
      output_size.size() == 2,
      "There should be exactly two elements (height, width) in output_size, but got ", output_size.size(), " elements.");

  int64_t dimw = 2;
  int64_t dimh = 1;
  int64_t numBatch = 1;

  int64_t numChannels;
  int64_t inputHeight;
  int64_t inputWidth;

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();

  if (self.ndimension() == 4) {
    numBatch = self.size(0);
    dimw++;
    dimh++;
  }
  numChannels = self.size(dimh - 1);
  inputHeight = self.size(dimh);
  inputWidth = self.size(dimw);

  output.resize_({numBatch, numChannels, oheight, owidth});

  output.zero_();

  auto count = self.numel();
  if (count != 0) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
        self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
          max_unpooling2d_forward_kernel<<<
              GET_BLOCKS(count),
              CUDA_NUM_THREADS,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              self.numel(),
              self.const_data_ptr<scalar_t>(),
              indices.const_data_ptr<int64_t>(),
              numChannels,
              inputHeight,
              inputWidth,
              oheight,
              owidth,
              output.mutable_data_ptr<scalar_t>());
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }));
  }
  if (self.ndimension() == 3) {
    output.resize_({numChannels, oheight, owidth});
  }
  return output;
}

Tensor max_unpooling2d_forward_cuda(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling2d_forward_out_cuda(self, indices, output_size, output);
  return output;
}

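// Shared argument validation for the 3-D forward and backward paths: checks
// the index dtype, dimensionality, output_size/stride/padding lengths,
// matching self/indices shapes, positive strides, and (when a defined
// gradOutput is passed) that gradOutput agrees with output_size and the
// channel/slice count.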
static void max_unpooling3d_shape_check(
    const Tensor& input,
    const Tensor& gradOutput,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding,
    const char *fn_name) {
  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];
  TORCH_CHECK(
      indices.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got: ", indices.scalar_type());
  TORCH_CHECK(
      (input.ndimension() == 4 || input.ndimension() == 5),
      "Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with dim ", input.ndimension());
  TORCH_CHECK(
      output_size.size() == 3,
      "There should be exactly three elements (depth, height, width) in output_size, but got ", output_size.size(), " elements.");
  TORCH_CHECK(
      stride.size() == 3,
      "There should be exactly three elements (depth, height, width) in stride, but got: ", stride.size(), " elements.");
  TORCH_CHECK(
      padding.size() == 3,
      "There should be exactly three elements (depth, height, width) in padding, but got: ", padding.size(), " elements.");
  TORCH_CHECK(
      input.sizes() == indices.sizes(),
      "Expected shape of indices to be: ", input.sizes(), " but got: ", indices.sizes());

  for (int64_t i = 1; i < input.ndimension(); ++i) {
    TORCH_CHECK(input.size(i) > 0, fn_name,
                ": Expected input to have non-zero size for non-batch dimensions, but got ",
                input.sizes(), " with dimension ", i , " being empty.");
  }

  TORCH_CHECK(
      stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
      "strides should be greater than zero, but got stride: ",
      stride);

  int dimw = 3;
  int dimh = 2;
  int dimt = 1;
  int dimn = 0;

  if (input.ndimension() == 5) {
    dimw++;
    dimh++;
    dimt++;
    dimn++;
  }

  int nslices = input.size(dimn);

  if (gradOutput.defined()) {
    if (oT != gradOutput.size(dimt) || oH != gradOutput.size(dimh) ||
        oW != gradOutput.size(dimw)) {
      AT_ERROR(
          "Inconsistent gradOutput size. oT= ",
          oT,
          ", oH= ",
          oH,
          ", oW= ",
          oW,
          ". gradOutput: ",
          gradOutput.size(dimt),
          "x",
          gradOutput.size(dimh),
          "x",
          gradOutput.size(dimw));
    }
    TORCH_CHECK(
        gradOutput.ndimension() == input.ndimension() &&
            gradOutput.size(dimn) == nslices,
        "gradOutput and input Tensors should have same number of dimensions and also the same number of channels/slices");
  }
}

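// Host entry point for max_unpool3d forward: validates arguments, resizes and
// zero-fills the output, collapses the batch and channel dimensions so the
// kernel sees a 4-D view, then launches the scatter kernel in chunks along
// gridDim.z.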
Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
                                         const Tensor& indices_,
                                         IntArrayRef output_size,
                                         IntArrayRef stride,
                                         IntArrayRef padding,
                                         Tensor& output) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic with duplicate indices
  at::globalContext().alertNotDeterministic("max_unpooling3d_forward_out");

  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
  max_unpooling3d_shape_check(
      self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cuda()");

  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

  TensorArg output_arg{output, "output", 1}, self_arg{self_, "self_", 2},
      indices_arg{indices_, "indices_", 3};
  checkAllSameGPU(
      "max_unpooling3d_forward_out_cuda", {output_arg, self_arg, indices_arg});

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();

  int64_t batchSize;
  int64_t inputSlices;
  int64_t inputTime;
  int64_t inputHeight;
  int64_t inputWidth;

  if (self.ndimension() == 4) {
    batchSize = 1;
    inputSlices = self.size(0);
    inputTime = self.size(1);
    inputHeight = self.size(2);
    inputWidth = self.size(3);
    output.resize_({inputSlices, oT, oH, oW});
  } else {
    batchSize = self.size(0);
    inputSlices = self.size(1);
    inputTime = self.size(2);
    inputHeight = self.size(3);
    inputWidth = self.size(4);
    output.resize_({batchSize, inputSlices, oT, oH, oW});
  }

  output.zero_();

  // Collapse batch and feature dimensions if needed
  if (self.ndimension() == 5) {
    self = self.reshape({self.size(0) * self.size(1),
                         self.size(2),
                         self.size(3),
                         self.size(4)});
    indices = indices.reshape({indices.size(0) * indices.size(1),
                               indices.size(2),
                               indices.size(3),
                               indices.size(4)});
  }

  if (self.numel() == 0) {
    return output;
  }

  int totalZ = inputTime * inputSlices * batchSize;
  int offsetZ = 0;
  dim3 block(32, 8);

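  // gridDim.z is capped at 65535, so the (batch * slices * frames) planes are
  // processed in chunks; offsetZ tells each launch where to resume. The same
  // chunking scheme is used by the backward pass below.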
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
      self.scalar_type(), "max_unpooling3d_forward_kernel", ([&] {
        while (totalZ > 0) {
          dim3 grid(
              ceilDiv(inputWidth, static_cast<int64_t>(block.x)),
              ceilDiv(inputHeight, static_cast<int64_t>(block.y)),
              totalZ > 65535 ? 65535 : totalZ);
          max_unpooling3d_forward_kernel<<<
              grid,
              block,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              self.packed_accessor64<const scalar_t, 4>(),
              indices.packed_accessor64<const int64_t, 4>(),
              output.mutable_data_ptr<scalar_t>(),
              oT,
              oH,
              oW,
              offsetZ);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          totalZ -= 65535;
          offsetZ += 65535;
        }
      }));
  return output;
}

Tensor max_unpooling3d_forward_cuda(
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto output = at::empty({0}, self.options());
  at::native::max_unpooling3d_forward_out_cuda(
      self, indices, output_size, stride, padding, output);
  return output;
}

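// Host entry point for max_unpool2d backward: validates arguments, checks
// grad_output against output_size, resizes and zero-fills grad_input, then
// launches the gather kernel over every element of the pooled input.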
at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_,
                                              const Tensor& self_,
                                              const Tensor& indices_,
                                              IntArrayRef output_size,
                                              Tensor& grad_input) {
  int64_t oheight = output_size[0];
  int64_t owidth = output_size[1];
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      indices_.scalar_type() == at::ScalarType::Long,
      "elements in indices should be type int64 but got type: ", indices_.scalar_type());
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2},
      self_arg{self_, "self_", 3}, indices_arg{indices_, "indices_", 4};
  checkAllSameGPU(
      "max_unpooling2d_backward_out_cuda",
      {grad_input_arg, grad_output_arg, self_arg, indices_arg});

  TORCH_CHECK(
      (self_.ndimension() == 3 || self_.ndimension() == 4),
      "Input to max_unpooling2d should be a 3d or 4d Tensor, instead got: ",
      self_);

  TORCH_CHECK(
      self_.sizes() == indices_.sizes(),
      "Expected shape of indices to be: ", self_.sizes(), " but got: ", indices_.sizes());

  TORCH_CHECK(output_size.size() == 2, "output_size must have two elements, got size: ", output_size.size());

  int64_t nInputCols, nInputRows, nInputPlane;

  int dimw = 2;
  int dimh = 1;

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();
  auto grad_output = grad_output_.contiguous();

  if (self.ndimension() == 3) {
    nInputPlane = self.size(0);
  } else {
    ++dimw;
    ++dimh;
    nInputPlane = self.size(1);
  }

  nInputCols = self.size(dimw);
  nInputRows = self.size(dimh);

  if (oheight != grad_output.size(dimh) || owidth != grad_output.size(dimw)) {
    AT_ERROR(
        "Inconsistent gradOutput size. output height: ",
        oheight,
        ", output width= ",
        owidth,
        ", gradOutput: ",
        grad_output.size(dimh),
        "x",
        grad_output.size(dimw));
  }

  grad_input.resize_as_(self);
  grad_input.zero_();

  int64_t count = self.numel();
  if (count == 0) {
    return grad_input;
  }

  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
      self.scalar_type(), "max_unpooling2d_backward_kernel", ([&] {
        max_unpooling2d_backward_kernel<<<
            GET_BLOCKS(count),
            CUDA_NUM_THREADS,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            count,
            grad_output.const_data_ptr<scalar_t>(),
            indices.const_data_ptr<int64_t>(),
            nInputPlane,
            nInputRows,
            nInputCols,
            oheight,
            owidth,
            grad_input.mutable_data_ptr<scalar_t>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }));
  return grad_input;
}
at::Tensor max_unpooling2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size) {
  auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  at::native::max_unpooling2d_backward_out_cuda(
      grad_output, self, indices, output_size, grad_input);
  return grad_input;
}

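// Host entry point for max_unpool3d backward: validates arguments via the
// shared shape check, zero-fills grad_input, collapses the batch and channel
// dimensions, and launches the gather kernel with the same gridDim.z chunking
// as the forward pass.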
at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
                                              const Tensor& self_,
                                              const Tensor& indices_,
                                              IntArrayRef output_size,
                                              IntArrayRef stride,
                                              IntArrayRef padding,
                                              Tensor& grad_input) {
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  int64_t oT = output_size[0];
  int64_t oH = output_size[1];
  int64_t oW = output_size[2];

  max_unpooling3d_shape_check(
      self_, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cuda()");

  int batchSize = 0;
  int inputSlices = 0;
  int inputTime = 0;
  int64_t inputHeight = 0;
  int64_t inputWidth = 0;

  TensorArg self_arg{self_, "self_", 1}, indices_arg{indices_, "indices_", 2},
      grad_output_arg{grad_output_, "grad_output_", 3},
      grad_input_arg{grad_input, "grad_input", 4};
  checkAllSameGPU(
      "max_unpooling3d_backward_out_cuda",
      {self_arg, indices_arg, grad_output_arg, grad_input_arg});

  auto self = self_.contiguous();
  auto indices = indices_.contiguous();
  auto grad_output = grad_output_.contiguous();

  if (self.ndimension() == 4) {
    batchSize = 1;
    inputSlices = self.size(0);
    inputTime = self.size(1);
    inputHeight = self.size(2);
    inputWidth = self.size(3);
  } else {
    batchSize = self.size(0);
    inputSlices = self.size(1);
    inputTime = self.size(2);
    inputHeight = self.size(3);
    inputWidth = self.size(4);
  }

  grad_input.resize_as_(self);
  grad_input.zero_();

  // Collapse batch and feature dimensions if needed
  auto grad_input_reshaped = grad_input;
  if (grad_input.ndimension() == 5) {
    grad_input_reshaped =
        grad_input.reshape({grad_input.size(0) * grad_input.size(1),
                            grad_input.size(2),
                            grad_input.size(3),
                            grad_input.size(4)});

    indices = indices.reshape({indices.size(0) * indices.size(1),
                               indices.size(2),
                               indices.size(3),
                               indices.size(4)});
  }
  if (grad_input.numel() == 0) {
    return grad_input;
  }

  int totalZ = inputTime * inputSlices * batchSize;
  int offsetZ = 0;

  dim3 block(32, 8);

  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16,
      self.scalar_type(), "max_unpooling3d_backward_kernel", ([&] {
        while (totalZ > 0) {
          dim3 grid(
              ceilDiv(inputWidth, static_cast<int64_t>(block.x)),
              ceilDiv(inputHeight, static_cast<int64_t>(block.y)),
              totalZ > 65535 ? 65535 : totalZ);
          max_unpooling3d_backward_kernel<<<
              grid,
              block,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              grad_output.const_data_ptr<scalar_t>(),
              oT,
              oH,
              oW,
              indices.packed_accessor64<int64_t, 4>(),
              grad_input_reshaped.packed_accessor64<scalar_t, 4>(),
              offsetZ);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          totalZ -= 65535;
          offsetZ += 65535;
        }
      }));
  return grad_input;
}

at::Tensor max_unpooling3d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& indices,
    IntArrayRef output_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  at::native::max_unpooling3d_backward_out_cuda(
      grad_output, self, indices, output_size, stride, padding, grad_input);
  return grad_input;
}

} // namespace at::native