#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>

#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/cuda/vol2col.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose3d_native.h>
#endif

namespace at::native {
namespace {

static inline void slow_conv_transpose3d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_depth,
    int kernel_width,
    int kernel_height,
    int stride_depth,
    int stride_width,
    int stride_height,
    int padding_depth,
    int padding_width,
    int padding_height,
    int dilation_depth,
    int dilation_width,
    int dilation_height,
    int output_padding_depth,
    int output_padding_width,
    int output_padding_height,
    int weight_nullable) {
  TORCH_CHECK(
      input.numel() != 0 && (input.dim() == 4 || input.dim() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ",
      input.sizes());
  TORCH_CHECK(
      stride_depth > 0 && stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_depth > 0 && dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_depth: ",
      dilation_depth,
      ", dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_depth < stride_depth ||
       output_padding_depth < dilation_depth) &&
          (output_padding_width < stride_width ||
           output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation,",
      " but got output_padding_depth: ",
      output_padding_depth,
      " output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_depth: ",
      dilation_depth,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  // number of input & output planes and kernel size is indirectly defined by
  // the weight tensor
  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() != 0 && weight.dim() == 5,
        "non-empty 5D (n_input_plane x n_output_plane ",
        "x kernel_depth x kernel_height x kernel_width) tensor ",
        "expected for weight, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimd = 1;
  int dimh = 2;
  int dimw = 3;

  if (ndim == 5) {
    dimf++;
    dimd++;
    dimh++;
    dimw++;
  }

  if (weight.defined()) {
    const int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  int64_t input_width = input.size(dimw);
  int64_t input_height = input.size(dimh);
  int64_t input_depth = input.size(dimd);

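  // Transposed-convolution output size along each spatial dimension:
  //   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding
  // Illustrative numbers (not taken from any particular call site):
  // in = 4, stride = 2, padding = 1, dilation = 1, kernel = 3,
  // output_padding = 0 gives out = 6 - 2 + 3 + 0 = 7.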
  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  if (output_depth < 1 || output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_depth,
        " x ",
        input_height,
        " x ",
        input_width,
        "). Calculated output size per channel: (",
        output_depth,
        " x ",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      const int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      const int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimd, output_depth);
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}

void slow_conv_transpose3d_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    const Tensor& bias_,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "Expected kernel_size to have 3 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "Expected dilation to have 3 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "Expected padding to have 3 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "Expected stride to have 3 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "Expected output_padding to have 3 elements, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  TensorArg input_arg{input_, "input", 1}, output_arg{output, "output", 2},
      weight_arg{weight_, "weight", 3}, bias_arg{bias_, "bias", 4};

  checkAllSameGPU(
      "slow_conv_transpose3d_out_cuda",
      {input_arg, output_arg, weight_arg, bias_arg});

  slow_conv_transpose3d_shape_check(
      input_,
      Tensor(),
      weight_,
      bias_,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor weight = weight_.contiguous();
  Tensor bias = bias_.defined() ? bias_.contiguous() : bias_;

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
  }

  int64_t input_width = input.size(4);
  int64_t input_height = input.size(3);
  int64_t input_depth = input.size(2);

  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  output.resize_(
      {batch_size, n_output_plane, output_depth, output_height, output_width});

  // Create temporary columns
  Tensor columns = at::empty(
      {n_output_plane * kernel_width * kernel_height * kernel_depth,
       input_depth * input_height * input_width},
      input.options());

  // Define a buffer of ones, for bias accumulation
  Tensor ones = bias.defined()
      ? at::ones({output_depth, output_height, output_width}, input_.options())
      : Tensor();
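  // The ones buffer is used as one operand of a rank-1 GEMM further below,
  // which adds bias[c] to every voxel of output channel c; it is only
  // allocated when a bias tensor is defined.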

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "slow_conv_transpose3d_out_cuda", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        // Helpers
        Tensor input_n;
        Tensor output_n;

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per output:
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m =
              weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);
          int64_t n = columns.size(1);
          int64_t k = weight.size(0);

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
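          // In row-major terms this computes
          //   columns (m x n) = weight^T (m x k) * input_n (k x n),
          // with weight viewed as (n_input_plane, n_output_plane * kD * kH * kW)
          // and input_n as (n_input_plane, input_depth * input_height * input_width).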
          at::cuda::blas::gemm<scalar_t>(
              'n',
              't',
              n,
              m,
              k,
              static_cast<scalar_t>(1),
              input_n.const_data_ptr<scalar_t>(),
              n,
              weight.const_data_ptr<scalar_t>(),
              m,
              static_cast<scalar_t>(0),
              columns.mutable_data_ptr<scalar_t>(),
              n);

          // Unpack columns back into the output volume:
          at::native::col2vol<scalar_t, accscalar_t>(
              at::cuda::getCurrentCUDAStream(),
              columns.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_depth,
              output_height,
              output_width,
              input_depth,
              input_height,
              input_width,
              kernel_depth,
              kernel_height,
              kernel_width,
              padding_depth,
              padding_height,
              padding_width,
              stride_depth,
              stride_height,
              stride_width,
              dilation_depth,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());

          // Do Bias after:
          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m_ = n_output_plane;
          int64_t n_ = output_depth * output_height * output_width;
          int64_t k_ = 1;

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
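          // This is a rank-1 update (k_ == 1): in row-major terms it computes
          //   output_n (m_ x n_) += bias (m_ x 1) * ones^T (1 x n_),
          // i.e. bias[c] is added at every spatial location of output channel c.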
          if (bias.defined()) {
            at::cuda::blas::gemm<scalar_t>(
                't',
                'n',
                n_,
                m_,
                k_,
                static_cast<scalar_t>(1),
                ones.const_data_ptr<scalar_t>(),
                k_,
                bias.const_data_ptr<scalar_t>(),
                k_,
                static_cast<scalar_t>(1),
                output_n.mutable_data_ptr<scalar_t>(),
                n_);
          }
        }

        // Resize output
        if (is_batch) {
          output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
        }
      });
}

void slow_conv_transpose3d_backward_out_cuda_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "Expected kernel_size to have 3 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "Expected dilation to have 3 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "Expected padding to have 3 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "Expected stride to have 3 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "Expected output_padding to have 3 elements, but got size ",
      output_padding.size());

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      weight_arg{weight_, "weight", 3},
      grad_input_arg{grad_input, "grad_input", 4};

  checkAllSameGPU(
      "slow_conv_transpose3d_backward_out_cuda",
      {input_arg,
       grad_output_arg,
       weight_arg,
       grad_input_arg});

  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();
  Tensor weight = weight_.contiguous();

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
                         grad_output.size(0),
                         grad_output.size(1),
                         grad_output.size(2),
                         grad_output.size(3)});
  }

  int64_t input_width = input.size(4);
  int64_t input_height = input.size(3);
  int64_t input_depth = input.size(2);
  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  grad_input.resize_(
      {batch_size, n_input_plane, input_depth, input_height, input_width});

  // Create temporary columns
  bool need_columns =
      (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
       stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
       dilation_depth != 1 || dilation_height != 1 ||
       dilation_width != 1 || padding_depth != 0 ||
       padding_height != 0 || padding_width != 0);
  Tensor grad_columns = need_columns
      ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
                   input_depth * input_height * input_width},
                  input.options())
      : Tensor();
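  // When the kernel is 1x1x1 with unit stride and dilation and zero padding,
  // vol2col is an identity mapping, so grad_output can feed the GEMM directly
  // and no column buffer is allocated.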

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "slow_conv_transpose3d_backward_out_cuda", [&] {
        // Helpers
        Tensor grad_input_n;
        Tensor grad_output_n;

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per sample:
          grad_input_n = grad_input.select(0, elt);
          grad_output_n = grad_output.select(0, elt);

          if (need_columns) {
            // Extract columns:
            at::native::vol2col<scalar_t>(
                at::cuda::getCurrentCUDAStream(),
                grad_output_n.const_data_ptr<scalar_t>(),
                n_output_plane,
                output_depth,
                output_height,
                output_width,
                input_depth,
                input_height,
                input_width,
                kernel_depth,
                kernel_height,
                kernel_width,
                padding_depth,
                padding_height,
                padding_width,
                stride_depth,
                stride_height,
                stride_width,
                dilation_depth,
                dilation_height,
                dilation_width,
                grad_columns.mutable_data_ptr<scalar_t>());
          }

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m = weight.size(0);
          int64_t n = input_depth * input_height * input_width;
          int64_t k =
              weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
          auto gemm_in_ptr = need_columns
              ? grad_columns.const_data_ptr<scalar_t>()
              : grad_output_n.const_data_ptr<scalar_t>();
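          // In row-major terms this computes
          //   grad_input_n (m x n) = weight (m x k) * columns (k x n),
          // where columns is the unfolded grad_output (or grad_output_n itself
          // when no unfolding is needed) and weight is viewed as
          // (n_input_plane, n_output_plane * kD * kH * kW).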
          at::cuda::blas::gemm<scalar_t>(
              'n',
              'n',
              n,
              m,
              k,
              static_cast<scalar_t>(1),
              gemm_in_ptr,
              n,
              weight.const_data_ptr<scalar_t>(),
              k,
              static_cast<scalar_t>(0),
              grad_input_n.mutable_data_ptr<scalar_t>(),
              n);
        }

        // Resize output
        if (is_batch) {
          grad_output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
          grad_input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
        }
      });
}

void slow_conv_transpose3d_acc_grad_parameters_cuda(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "Expected kernel_size to have 3 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "Expected dilation to have 3 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "Expected padding to have 3 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "Expected stride to have 3 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "Expected output_padding to have 3 elements, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      grad_weight_arg{grad_weight, "grad_weight", 3},
      grad_bias_arg{grad_bias, "grad_bias", 4};

  checkAllSameGPU(
      "slow_conv_transpose3d_acc_grad_parameters_cuda",
      {input_arg,
       grad_output_arg,
       grad_weight_arg,
       grad_bias_arg});

  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      1);

  int n_output_plane;
  if (grad_weight.defined()) {
    n_output_plane = grad_weight.size(1);
  } else if (grad_bias.defined()) {
    n_output_plane = grad_bias.size(0);
  } else {
    return;
  }

  if (grad_weight.defined()) {
    TORCH_CHECK(
        grad_weight.is_contiguous(), "grad_weight needs to be contiguous");
  }
  if (grad_bias.defined()) {
    TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous");
  }

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
                         grad_output.size(0),
                         grad_output.size(1),
                         grad_output.size(2),
                         grad_output.size(3)});
  }

  int64_t input_width = input.size(4);
  int64_t input_height = input.size(3);
  int64_t input_depth = input.size(2);

  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Create temporary columns
  bool need_columns =
      (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
       stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
       dilation_depth != 1 || dilation_height != 1 ||
       dilation_width != 1 || padding_depth != 0 ||
       padding_height != 0 || padding_width != 0);
  Tensor columns = need_columns
      ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
                   input_depth * input_height * input_width},
                  input.options())
      : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(),
      "slow_conv_transpose3d_acc_grad_parameters_cuda",
      [&] {
        // Helpers
        Tensor input_n;
        Tensor grad_output_n;

        scalar_t scale = static_cast<scalar_t>(scale_);

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per output:
          grad_output_n = grad_output.select(0, elt);

          // Do Weight:
          if (grad_weight.defined()) {
            // Matrix multiply per output:
            input_n = input.select(0, elt);

            if (need_columns) {
              // Extract columns:
              at::native::vol2col<scalar_t>(
                  at::cuda::getCurrentCUDAStream(),
                  grad_output_n.const_data_ptr<scalar_t>(),
                  n_output_plane,
                  output_depth,
                  output_height,
                  output_width,
                  input_depth,
                  input_height,
                  input_width,
                  kernel_depth,
                  kernel_height,
                  kernel_width,
                  padding_depth,
                  padding_height,
                  padding_width,
                  stride_depth,
                  stride_height,
                  stride_width,
                  dilation_depth,
                  dilation_height,
                  dilation_width,
                  columns.mutable_data_ptr<scalar_t>());
            }

            // M,N,K are dims of matrix A and B
            // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
            int64_t n =
                n_output_plane * kernel_width * kernel_height * kernel_depth;
            int64_t m = input_n.size(0); // n_input_plane
            int64_t k = input_depth * input_height * input_width;

            // Do GEMM (note: this is a bit confusing because gemm assumes
            // column-major matrices)
            auto gemm_in_ptr = need_columns
                ? columns.const_data_ptr<scalar_t>()
                : grad_output_n.const_data_ptr<scalar_t>();
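            // In row-major terms this accumulates
            //   grad_weight (m x n) += scale * input_n (m x k) * columns^T (k x n),
            // with grad_weight viewed as
            // (n_input_plane, n_output_plane * kD * kH * kW) and columns being
            // the unfolded grad_output (or grad_output_n when no unfolding is
            // needed).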
            at::cuda::blas::gemm<scalar_t>(
                't',
                'n',
                n,
                m,
                k,
                scale,
                gemm_in_ptr,
                k,
                input_n.const_data_ptr<scalar_t>(),
                k,
                static_cast<scalar_t>(1),
                grad_weight.mutable_data_ptr<scalar_t>(),
                n);
          }
        }

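        // grad_bias is simply grad_output summed over the batch and the three
        // spatial dimensions, leaving one value per output channel.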
        if (grad_bias.defined()) {
          at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3, 4});
        }

        // Resize
        if (is_batch) {
          grad_output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {input.size(1), input_depth, input_height, input_width});
        }
      });
}

} // namespace

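// Shape summary for the entry points below: a 5D input
// (N, n_input_plane, D, H, W) (or 4D without the batch dimension) and a weight
// of shape (n_input_plane, n_output_plane, kD, kH, kW) produce an output of
// shape (N, n_output_plane, D_out, H_out, W_out), where the *_out sizes follow
// the transposed-convolution formula used in the shape check above.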
Tensor& slow_conv_transpose3d_out_cuda(const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  slow_conv_transpose3d_out_cuda_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}

Tensor slow_conv_transpose3d_cuda(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  slow_conv_transpose3d_out_cuda_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv_transpose3d_backward_out_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  if (grad_input.defined()) {
    slow_conv_transpose3d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose3d_acc_grad_parameters_cuda(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor&, Tensor&, Tensor&>(
      grad_input, grad_weight, grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose3d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose3d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose3d_acc_grad_parameters_cuda(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_CUDA_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cuda);

} // namespace at::native