xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>

#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/cuda/vol2col.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose3d_native.h>
#endif

namespace at::native {
namespace {

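// Naive ("slow") transposed 3D convolution. The forward pass lowers the
// operation to one GEMM per batch element followed by a col2vol scatter-add;
// the backward passes reuse the same lowering, with vol2col and GEMMs
// against the weight.
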
static inline void slow_conv_transpose3d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_depth,
    int kernel_width,
    int kernel_height,
    int stride_depth,
    int stride_width,
    int stride_height,
    int padding_depth,
    int padding_width,
    int padding_height,
    int dilation_depth,
    int dilation_width,
    int dilation_height,
    int output_padding_depth,
    int output_padding_width,
    int output_padding_height,
    int weight_nullable) {
  TORCH_CHECK(
      input.numel() != 0 && (input.dim() == 4 || input.dim() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ",
      input.sizes());
  TORCH_CHECK(
      stride_depth > 0 && stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_depth > 0 && dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_depth: ",
      dilation_depth,
      ", dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_depth < stride_depth ||
       output_padding_depth < dilation_depth) &&
          (output_padding_width < stride_width ||
           output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation,",
      " but got output_padding_depth: ",
      output_padding_depth,
      " output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_depth: ",
      dilation_depth,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  // The number of input & output planes and the kernel size are indirectly
  // defined by the weight tensor.
  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() != 0 && weight.dim() == 5,
        "non-empty 5D (n_input_plane x n_output_plane ",
        "x kernel_depth x kernel_height x kernel_width) tensor ",
        "expected for weight, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimd = 1;
  int dimh = 2;
  int dimw = 3;

  if (ndim == 5) {
    dimf++;
    dimd++;
    dimh++;
    dimw++;
  }

  if (weight.defined()) {
    const int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  int64_t input_width = input.size(dimw);
  int64_t input_height = input.size(dimh);
  int64_t input_depth = input.size(dimd);

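  // Transposed-convolution output size along each spatial dimension:
  //   out = (in - 1) * stride - 2 * padding
  //         + dilation * (kernel - 1) + 1 + output_padding
  // e.g. in = 4, stride = 2, padding = 1, kernel = 3, dilation = 1 and
  // output_padding = 0 give out = 3 * 2 - 2 + 3 + 0 = 7.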
  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  if (output_depth < 1 || output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_depth,
        " x ",
        input_height,
        " x ",
        input_width,
        "). Calculated output size per channel: (",
        output_depth,
        " x ",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      const int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      const int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimd, output_depth);
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}

void slow_conv_transpose3d_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    const Tensor& bias_,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected kernel_size equals to 3, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected dilation equals to 3, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected padding equals to 3, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected stride equals to 3, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected output_padding equals to 3, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  TensorArg input_arg{input_, "input", 1}, output_arg{output, "output", 2},
      weight_arg{weight_, "weight", 3}, bias_arg{bias_, "bias", 4};

  checkAllSameGPU(
      "slow_conv_transpose3d_out_cuda",
      {input_arg, output_arg, weight_arg, bias_arg});

  slow_conv_transpose3d_shape_check(
      input_,
      Tensor(),
      weight_,
      bias_,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor weight = weight_.contiguous();
  Tensor bias = bias_.defined() ? bias_.contiguous() : bias_;

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
  }

  int64_t input_width = input.size(4);
  int64_t input_height = input.size(3);
  int64_t input_depth = input.size(2);

  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  output.resize_(
      {batch_size, n_output_plane, output_depth, output_height, output_width});

  // Create temporary columns
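  // columns has one row per (output plane, kernel offset) pair and one
  // column per input voxel; col2vol later scatter-adds it into the output.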
  Tensor columns = at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options());

  // Define a buffer of ones, for bias accumulation
  Tensor ones = bias.defined() ? at::ones({output_depth, output_height, output_width}, input_.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "slow_conv_transpose3d_out_cuda", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        // Helpers
        Tensor input_n;
        Tensor output_n;

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per output:
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m =
              weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);
          int64_t n = columns.size(1);
          int64_t k = weight.size(0);

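          // Row-major view: columns(m, n) = weight(k, m)^T * input_n(k, n),
          // with m = n_output_plane * kD * kH * kW, n = iD * iH * iW and
          // k = n_input_plane; cuBLAS is column-major, so the call below
          // computes the transposed product.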
          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
          at::cuda::blas::gemm<scalar_t>(
              'n',
              't',
              n,
              m,
              k,
              static_cast<scalar_t>(1),
              input_n.const_data_ptr<scalar_t>(),
              n,
              weight.const_data_ptr<scalar_t>(),
              m,
              static_cast<scalar_t>(0),
              columns.mutable_data_ptr<scalar_t>(),
              n);

          // Unpack columns back into the output:
          at::native::col2vol<scalar_t, accscalar_t>(
              at::cuda::getCurrentCUDAStream(),
              columns.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_depth,
              output_height,
              output_width,
              input_depth,
              input_height,
              input_width,
              kernel_depth,
              kernel_height,
              kernel_width,
              padding_depth,
              padding_height,
              padding_width,
              stride_depth,
              stride_height,
              stride_width,
              dilation_depth,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());

          // Do Bias after:
          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m_ = n_output_plane;
          int64_t n_ = output_depth * output_height * output_width;
          int64_t k_ = 1;

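          // With k_ == 1 the GEMM below is a rank-1 update:
          // output_n(c, x) += bias(c) * ones(x), i.e. the bias is broadcast
          // over every output voxel (beta == 1 keeps the GEMM+col2vol result).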
          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
          if (bias.defined()) {
            at::cuda::blas::gemm<scalar_t>(
                't',
                'n',
                n_,
                m_,
                k_,
                static_cast<scalar_t>(1),
                ones.const_data_ptr<scalar_t>(),
                k_,
                bias.const_data_ptr<scalar_t>(),
                k_,
                static_cast<scalar_t>(1),
                output_n.mutable_data_ptr<scalar_t>(),
                n_);
          }
        }

        // Resize output
        if (is_batch) {
          output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
        }
      });
}

void slow_conv_transpose3d_backward_out_cuda_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected kernel_size equals to 3, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected dilation equals to 3, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected padding equals to 3, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected stride equals to 3, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected output_padding equals to 3, but got size ",
      output_padding.size());

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      weight_arg{weight_, "weight", 3},
      grad_input_arg{grad_input, "grad_input", 4};

  checkAllSameGPU(
      "slow_conv_transpose3d_backward_out_cuda",
      {input_arg,
       grad_output_arg,
       weight_arg,
       grad_input_arg});

  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();
  Tensor weight = weight_.contiguous();

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
                         grad_output.size(0),
                         grad_output.size(1),
                         grad_output.size(2),
                         grad_output.size(3)});
  }

  int64_t input_width = input.size(4);
  int64_t input_height = input.size(3);
  int64_t input_depth = input.size(2);
  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  grad_input.resize_(
      {batch_size, n_input_plane, input_depth, input_height, input_width});

  // Create temporary columns
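  // vol2col is the identity for a 1x1x1 kernel with unit stride and
  // dilation and zero padding; in that case grad_output can feed the GEMM
  // directly and no column buffer is materialized.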
  bool need_columns = (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
      stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
      dilation_depth != 1 || dilation_height != 1 ||
      dilation_width != 1 || padding_depth != 0 ||
      padding_height != 0 || padding_width != 0);
  Tensor grad_columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "slow_conv_transpose3d_backward_out_cuda", [&] {
        // Helpers
        Tensor grad_input_n;
        Tensor grad_output_n;

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per sample:
          grad_input_n = grad_input.select(0, elt);
          grad_output_n = grad_output.select(0, elt);

          if (need_columns) {
            // Extract columns:
            at::native::vol2col<scalar_t>(
                at::cuda::getCurrentCUDAStream(),
                grad_output_n.const_data_ptr<scalar_t>(),
                n_output_plane,
                output_depth,
                output_height,
                output_width,
                input_depth,
                input_height,
                input_width,
                kernel_depth,
                kernel_height,
                kernel_width,
                padding_depth,
                padding_height,
                padding_width,
                stride_depth,
                stride_height,
                stride_width,
                dilation_depth,
                dilation_height,
                dilation_width,
                grad_columns.mutable_data_ptr<scalar_t>());
          }

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m = weight.size(0);
          int64_t n = input_depth * input_height * input_width;
          int64_t k =
              weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);

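          // Row-major view: grad_input_n(m, n) = weight(m, k) * columns(k, n),
          // with m = n_input_plane, n = iD * iH * iW and
          // k = n_output_plane * kD * kH * kW.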
          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
          auto gemm_in_ptr = need_columns ? grad_columns.const_data_ptr<scalar_t>()
              : grad_output_n.const_data_ptr<scalar_t>();
          at::cuda::blas::gemm<scalar_t>(
              'n',
              'n',
              n,
              m,
              k,
              static_cast<scalar_t>(1),
              gemm_in_ptr,
              n,
              weight.const_data_ptr<scalar_t>(),
              k,
              static_cast<scalar_t>(0),
              grad_input_n.mutable_data_ptr<scalar_t>(),
              n);
        }

        // Resize output
        if (is_batch) {
          grad_output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
          grad_input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
        }
      });
}

void slow_conv_transpose3d_acc_grad_parameters_cuda(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected kernel_size equals to 3, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected dilation equals to 3, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected padding equals to 3, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected stride equals to 3, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected output_padding equals to 3, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      grad_weight_arg{grad_weight, "grad_weight", 3},
      grad_bias_arg{grad_bias, "grad_bias", 4};

  checkAllSameGPU(
      "slow_conv_transpose3d_acc_grad_parameters_cuda",
      {input_arg,
       grad_output_arg,
       grad_weight_arg,
       grad_bias_arg});

  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      1);

  int n_output_plane;
  if (grad_weight.defined()) {
    n_output_plane = grad_weight.size(1);
  } else if (grad_bias.defined()) {
    n_output_plane = grad_bias.size(0);
  } else {
    return;
  }

  if (grad_weight.defined()) {
    TORCH_CHECK(
        grad_weight.is_contiguous(), "grad_weight needs to be contiguous");
  }
  if (grad_bias.defined()) {
    TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous");
  }

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
                         grad_output.size(0),
                         grad_output.size(1),
                         grad_output.size(2),
                         grad_output.size(3)});
  }

  int64_t input_width = input.size(4);
  int64_t input_height = input.size(3);
  int64_t input_depth = input.size(2);

  int64_t output_depth = (input_depth - 1) * stride_depth - 2 * padding_depth +
      (dilation_depth * (kernel_depth - 1) + 1) + output_padding_depth;
  int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * padding_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Create temporary columns
  bool need_columns = (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
      stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
      dilation_depth != 1 || dilation_height != 1 ||
      dilation_width != 1 || padding_depth != 0 ||
      padding_height != 0 || padding_width != 0);
  Tensor columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(),
      "slow_conv_transpose3d_acc_grad_parameters_cuda",
      [&] {
        // Helpers
        Tensor input_n;
        Tensor grad_output_n;

        scalar_t scale = static_cast<scalar_t>(scale_);

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per output:
          grad_output_n = grad_output.select(0, elt);

          // Do Weight:
          if (grad_weight.defined()) {
            // Matrix multiply per output:
            input_n = input.select(0, elt);

            if (need_columns) {
              // Extract columns:
              at::native::vol2col<scalar_t>(
                  at::cuda::getCurrentCUDAStream(),
                  grad_output_n.const_data_ptr<scalar_t>(),
                  n_output_plane,
                  output_depth,
                  output_height,
                  output_width,
                  input_depth,
                  input_height,
                  input_width,
                  kernel_depth,
                  kernel_height,
                  kernel_width,
                  padding_depth,
                  padding_height,
                  padding_width,
                  stride_depth,
                  stride_height,
                  stride_width,
                  dilation_depth,
                  dilation_height,
                  dilation_width,
                  columns.mutable_data_ptr<scalar_t>());
            }

            // M,N,K are dims of matrix A and B
            // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
            int64_t n = n_output_plane * kernel_width * kernel_height * kernel_depth;
            int64_t m = input_n.size(0); // n_input_plane
            int64_t k = input_depth * input_height * input_width;

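            // Row-major view: grad_weight(m, n) += scale * input_n(m, k) *
            // columns(n, k)^T; beta == 1 accumulates the per-sample
            // contributions across the batch.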
            // Do GEMM (note: this is a bit confusing because gemm assumes
            // column-major matrices)
            auto gemm_in_ptr = need_columns ? columns.const_data_ptr<scalar_t>() : grad_output_n.const_data_ptr<scalar_t>();
            at::cuda::blas::gemm<scalar_t>(
                't',
                'n',
                n,
                m,
                k,
                scale,
                gemm_in_ptr,
                k,
                input_n.const_data_ptr<scalar_t>(),
                k,
                static_cast<scalar_t>(1),
                grad_weight.mutable_data_ptr<scalar_t>(),
                n);
          }
        }

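        // The bias gradient reduces grad_output over the batch and spatial
        // dimensions (dims 0, 2, 3, 4), one value per output plane.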
        if (grad_bias.defined()) {
          at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3, 4});
        }

        // Resize
        if (is_batch) {
          grad_output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {input.size(1), input_depth, input_height, input_width});
        }
      });
}

} // namespace

Tensor& slow_conv_transpose3d_out_cuda(const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  slow_conv_transpose3d_out_cuda_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}

Tensor slow_conv_transpose3d_cuda(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  slow_conv_transpose3d_out_cuda_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}
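
// A shape sketch for the function above (hypothetical sizes): input
// (1, 2, 4, 4, 4) with weight (2, 3, 3, 3, 3), stride 2, padding 1,
// output_padding 1 and dilation 1 yields output (1, 3, 8, 8, 8), since
// (4 - 1) * 2 - 2 * 1 + (1 * (3 - 1) + 1) + 1 = 8 along each spatial dim.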

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv_transpose3d_backward_out_cuda(const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  if (grad_input.defined()) {
    slow_conv_transpose3d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose3d_acc_grad_parameters_cuda(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor&, Tensor&, Tensor&>(
      grad_input, grad_weight, grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose3d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose3d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose3d_acc_grad_parameters_cuda(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_CUDA_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cuda);

} // namespace at::native