#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/vol2col.cuh>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cuda/im2col.cuh>
#include <ATen/native/DilatedConvolutionUtils.h>
#include <c10/util/accumulate.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_dilated2d_native.h>
#include <ATen/ops/slow_conv_dilated3d_native.h>
#endif

#include <tuple>

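// Naive ("slow") dilated convolution kernels for CUDA. The implementation
// follows the classic unfold + GEMM approach: each sample of the batch is
// unfolded into a 2-D "columns" buffer (im2col for 2-D, vol2col for 3-D),
// after which the convolution and its gradients reduce to cuBLAS matrix
// multiplications.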
namespace at::native {

namespace {

// hyper-volume to column, CUDA
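// Unfolds one sample of shape (channels, *input_size) into a columns matrix
// of shape (channels * prod(kernel_size), prod(output_size)), dispatching to
// im2col for dim == 2 and vol2col for dim == 3.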
template <typename Dtype, int64_t dim>
void hvol2col(
    cudaStream_t stream,
    const Dtype* data_hvol,
    const int channels,
    const IntArrayRef input_size,
    const IntArrayRef output_size,
    const IntArrayRef kernel_size,
    const IntArrayRef stride_size,
    const IntArrayRef pad_size,
    const IntArrayRef dilation_size,
    Dtype* data_col) {
  if (dim == 3) {
    vol2col<Dtype>(
        stream,
        data_hvol,
        channels,
        input_size[0],
        input_size[1],
        input_size[2],
        output_size[0],
        output_size[1],
        output_size[2],
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        pad_size[0],
        pad_size[1],
        pad_size[2],
        stride_size[0],
        stride_size[1],
        stride_size[2],
        dilation_size[0],
        dilation_size[1],
        dilation_size[2],
        data_col);
  }
  if (dim == 2) {
    im2col<Dtype>(
        stream,
        data_hvol,
        channels,
        input_size[0],
        input_size[1],
        output_size[0],
        output_size[1],
        kernel_size[0],
        kernel_size[1],
        pad_size[0],
        pad_size[1],
        stride_size[0],
        stride_size[1],
        dilation_size[0],
        dilation_size[1],
        data_col);
  }
}

// column to hyper-volume, CUDA
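// Inverse of hvol2col: folds a columns matrix back into a volume of shape
// (channels, *input_size), summing the contributions of overlapping
// receptive fields (used here to accumulate grad_input), dispatching to
// col2im for dim == 2 and col2vol for dim == 3.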
template <typename Dtype, int64_t dim>
void col2hvol(
    cudaStream_t stream,
    const Dtype* data_col,
    const int channels,
    const IntArrayRef input_size,
    const IntArrayRef output_size,
    const IntArrayRef kernel_size,
    const IntArrayRef stride_size,
    const IntArrayRef pad_size,
    const IntArrayRef dilation_size,
    Dtype* data_hvol) {
  if (dim == 3) {
    col2vol<Dtype, Dtype>(
        stream,
        data_col,
        channels,
        input_size[0],
        input_size[1],
        input_size[2],
        output_size[0],
        output_size[1],
        output_size[2],
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        pad_size[0],
        pad_size[1],
        pad_size[2],
        stride_size[0],
        stride_size[1],
        stride_size[2],
        dilation_size[0],
        dilation_size[1],
        dilation_size[2],
        data_hvol);
  }
  if (dim == 2) {
    col2im<Dtype, Dtype>(
        stream,
        data_col,
        channels,
        input_size[0],
        input_size[1],
        output_size[0],
        output_size[1],
        kernel_size[0],
        kernel_size[1],
        pad_size[0],
        pad_size[1],
        stride_size[0],
        stride_size[1],
        dilation_size[0],
        dilation_size[1],
        data_hvol);
  }
}

/*
   check tensor data locations
*/
void slow_conv_dilated_location_check(
    CheckedFrom c,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output) {
  // checking data locations of user-provided tensor arguments
  TensorArg input_arg{input, "input", 2}, weight_arg{weight, "weight", 3},
      bias_arg{bias, "bias", 4}, grad_output_arg{grad_output, "grad_output", 5};
  checkAllSameGPU(c, {input_arg, weight_arg});
  if (bias.defined()) {
    checkAllSameGPU(c, {input_arg, bias_arg});
  }
  if (grad_output.defined()) {
    checkAllSameGPU(c, {input_arg, grad_output_arg});
  }
  // We do not check the data locations of the other tensor arguments
  // (output, grad_input, etc.) because they are allocated based on the
  // input's options and therefore always live on the same device as the
  // input tensor.
}

/*
  slow_conv_dilated_all_cuda_template

  Main worker. Computes tensors output, grad_input, grad_weight,
  and/or grad_bias if defined, respectively.
 */

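// Implementation notes: a single "columns" buffer of shape
// (nInputPlane * prod(kernel_size), prod(output_size)) is reused for every
// sample in the batch. With k = nInputPlane * prod(kernel_size), the forward
// pass computes output_n = weight.view({nOutputPlane, k}) @ columns (plus
// bias), the input gradient computes
// columns = weight.view({nOutputPlane, k}).t() @ grad_output_n followed by
// col2hvol, and the weight gradient accumulates
// grad_weight += grad_output_n (flattened) @ columns.t().
// cuBLAS expects column-major storage, so the gemm calls below pass the
// row-major operands in swapped/transposed form; see the per-call comments.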
template <int64_t dim>
void slow_conv_dilated_all_cuda_template(
    Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  slow_conv_dilated_location_check(__func__, input, weight, bias, grad_output);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto options = input.options();
  // The rear part of input tensor sizes:
  auto input_size = input.sizes().slice(2);
  // The rear part of output tensor sizes:
  auto output_size = internal::get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  int64_t batchSize = input.size(0);
  int64_t nInputPlane = weight.size(1);
  int64_t nOutputPlane = weight.size(0);
  // Temporary buffers:
  const int64_t m = c10::multiply_integers(kernel_size);
  const int64_t output_vsize = c10::multiply_integers(output_size);
  Tensor columns = at::empty({0}, options);
  if (output.defined() || grad_weight.defined() || grad_input.defined()) {
    columns.resize_({nInputPlane * m, output_vsize});
  }
  // Initialize
  if (grad_weight.defined()) {
    grad_weight.zero_();
  }
  if (grad_bias.defined()) {
    grad_bias.zero_();
  }
  if (output.defined() && !bias.defined()) {
    output.zero_();
  }

#if defined(USE_ROCM)
  /* When using ROCm, the sum evaluation is inaccurate for double
     tensors. The reason is currently unknown. Hence, we use gemv for
     computing `grad_output_n.sum(dims)` until the ROCm-sum issue is
     resolved. */
  Tensor ones = at::empty({0}, options);
  if (grad_bias.defined()) {
    ones.resize_({output_vsize});
    ones.fill_(1);
  }
  /* MSVC does not like #ifdef-s inside the CPP macro
     AT_DISPATCH_FLOATING_TYPES_AND2. So, we define the code
     branching outside the CPP macro: */
#define CALCULATE_GRAD_BIAS                                \
  at::cuda::blas::gemv<scalar_t>(                          \
      /*trans=*/'t',                                       \
      /*    m=*/output_vsize,                              \
      /*    n=*/nOutputPlane,                              \
      /*alpha=*/static_cast<scalar_t>(1),                  \
      /*    A=*/grad_output_n.const_data_ptr<scalar_t>(),  \
      /*  lda=*/output_vsize,                              \
      /*    x=*/ones.const_data_ptr<scalar_t>(),           \
      /* incx=*/1,                                         \
      /* beta=*/static_cast<scalar_t>(1),                  \
      /*    y=*/grad_bias.mutable_data_ptr<scalar_t>(),    \
      /* incy=*/1)
#else
#define CALCULATE_GRAD_BIAS grad_bias += grad_output_n.sum(dims)
#endif
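  // In both branches CALCULATE_GRAD_BIAS accumulates the per-channel sum of
  // grad_output_n over the spatial dimensions into grad_bias; the ROCm path
  // expresses that sum as the matrix-vector product grad_output_n * ones
  // computed with gemv.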

  // Helpers
  Tensor grad_output_n;
  std::vector<int64_t> dims(dim);
  std::iota(dims.begin(), dims.end(), 1);

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "slow_conv_dilated<>", [&] {
        // For each elt in batch, do:
        for (int elt = 0; elt < batchSize; elt++) {
          // Matrix multiply per output:
          Tensor input_n = input.select(0, elt);

          // Output
          if (output.defined()) {
            Tensor output_n = output.select(0, elt);
            if (bias.defined()) {
              /* For gemm argument derivation, see
                 slow_conv_dilated_all_cuda_template in
                 ATen/native/DilatedConvolution.cpp */
              for (int n = 0; n < nOutputPlane; n++) {
                output_n.select(0, n).fill_(bias[n]);
              }
            }
            // Extract columns:
            hvol2col<scalar_t, dim>(
                stream,
                input_n.const_data_ptr<scalar_t>(),
                nInputPlane,
                input_size,
                output_size,
                kernel_size,
                stride_size,
                pad_size,
                dilation_size,
                columns.mutable_data_ptr<scalar_t>());
            /* For gemm argument derivation, see
               slow_conv_dilated_all_cuda_template in
               ATen/native/DilatedConvolution.cpp */
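            /* cuBLAS gemm is column-major, so the row-major product
                 output_n (nOutputPlane x out_vsize)
                   = weight (nOutputPlane x k) * columns (k x out_vsize),
               with k = nInputPlane * m, is computed as its transpose
                 output_n^T = columns^T * weight^T
               by passing columns as A and weight as B, neither transposed.
               beta = 1 keeps the bias values already written into output_n. */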
            at::cuda::blas::gemm<scalar_t>(
                /*transa=*/'n',
                /*transb=*/'n',
                /*     m=*/columns.size(1),
                /*     n=*/nOutputPlane,
                /*     k=*/columns.size(0),
                /* alpha=*/static_cast<scalar_t>(1),
                /*     A=*/columns.const_data_ptr<scalar_t>(),
                /*   lda=*/columns.size(1),
                /*     B=*/weight.const_data_ptr<scalar_t>(),
                /*   ldb=*/columns.size(0),
                /*  beta=*/static_cast<scalar_t>(1),
                /*     C=*/output_n.mutable_data_ptr<scalar_t>(),
                /*   ldc=*/columns.size(1));

          } else {
            // All gradients
            grad_output_n = grad_output.select(0, elt);
          }

          // Gradient of input:
          if (grad_input.defined()) {
            /* For gemm argument derivation, see
               slow_conv_dilated_all_cuda_template in
               ATen/native/DilatedConvolution.cpp */
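            /* Row-major target:
                 columns (k x out_vsize)
                   = weight^T (k x nOutputPlane) * grad_output_n (nOutputPlane x out_vsize).
               In column-major terms this is columns^T = grad_output_n^T * weight,
               hence A = grad_output_n ('n') and B = weight ('t'); beta = 0
               overwrites the columns buffer, which col2hvol below folds back
               into grad_input_n. */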
            at::cuda::blas::gemm<scalar_t>(
                /*transa=*/'n',
                /*transb=*/'t',
                /*     m=*/columns.size(1),
                /*     n=*/columns.size(0),
                /*     k=*/nOutputPlane,
                /* alpha=*/static_cast<scalar_t>(1),
                /*     A=*/grad_output_n.const_data_ptr<scalar_t>(),
                /*   lda=*/columns.size(1),
                /*     B=*/weight.const_data_ptr<scalar_t>(),
                /*   ldb=*/columns.size(0),
                /*  beta=*/static_cast<scalar_t>(0),
                /*     C=*/columns.mutable_data_ptr<scalar_t>(),
                /*   ldc=*/columns.size(1));
            // Unpack columns back into input:
            Tensor grad_input_n = grad_input.select(0, elt);

            col2hvol<scalar_t, dim>(
                stream,
                columns.const_data_ptr<scalar_t>(),
                nInputPlane,
                input_size,
                output_size,
                kernel_size,
                stride_size,
                pad_size,
                dilation_size,
                grad_input_n.mutable_data_ptr<scalar_t>());
          }

          // Gradient of weight:
          if (grad_weight.defined()) {
            // Extract columns:
            hvol2col<scalar_t, dim>(
                stream,
                input_n.const_data_ptr<scalar_t>(),
                nInputPlane,
                input_size,
                output_size,
                kernel_size,
                stride_size,
                pad_size,
                dilation_size,
                columns.mutable_data_ptr<scalar_t>());
            scalar_t scale = static_cast<scalar_t>(
                1); // TODO: expose as argument?
            /* For gemm argument derivation, see
               slow_conv_dilated_all_cuda_template in
               ATen/native/DilatedConvolution.cpp */
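            /* Row-major target:
                 grad_weight (nOutputPlane x k)
                   += scale * grad_output_n (nOutputPlane x out_vsize) * columns^T (out_vsize x k).
               Column-major equivalent: grad_weight^T += scale * columns * grad_output_n^T,
               hence A = columns ('t') and B = grad_output_n ('n'); beta = 1
               accumulates over the batch (grad_weight was zeroed above). */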
            at::cuda::blas::gemm<scalar_t>(
                /*transa=*/'t',
                /*transb=*/'n',
                /*     m=*/columns.size(0),
                /*     n=*/nOutputPlane,
                /*     k=*/columns.size(1),
                /* alpha=*/scale,
                /*     A=*/columns.const_data_ptr<scalar_t>(),
                /*   lda=*/columns.size(1),
                /*     B=*/grad_output_n.const_data_ptr<scalar_t>(),
                /*   ldb=*/columns.size(1),
                /*  beta=*/static_cast<scalar_t>(1),
                /*     C=*/grad_weight.mutable_data_ptr<scalar_t>(),
                /*   ldc=*/columns.size(0));
          }

          // Gradient of bias:
          if (grad_bias.defined()) {
            /* For gemv argument derivation, see
               slow_conv_dilated_all_cpu_template in
               ATen/native/DilatedConvolution.cpp */
            CALCULATE_GRAD_BIAS; /* MSVC does not like #ifdef-s
                                    inside the CPP macros, see above. */
            /*
              TODO: when scale != 1 is introduced then use:
                grad_bias += scale * grad_output_n.sum(dims);
             */
          }
        }
      });

} // slow_conv_dilated_all_cuda_template

} // namespace

Tensor slow_conv_dilated2d_cuda(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<2>(
      input,
      weight,
      bias,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 4;
  auto options = input.options();
  // calculate output tensor size
  auto output_size = internal::get_output_size<2>(
      input, weight, kernel_size, stride_size, pad_size, dilation_size);
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  const Tensor bias_ = (bias.defined() ? bias.contiguous() : undefined);
  Tensor output = at::empty(output_size, options);
  Tensor output_ = (is_batch ? output : output.unsqueeze(0));

  slow_conv_dilated_all_cuda_template<2>(
      output_,
      input_,
      weight_,
      bias_,
      undefined,
      undefined,
      undefined,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return output;
}
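
// A minimal usage sketch (assuming a CUDA build and float inputs); this
// kernel is normally reached through the at::slow_conv_dilated2d op, which
// dispatches here for CUDA tensors:
//
//   at::Tensor input  = at::randn({1, 3, 16, 16}, at::kCUDA);
//   at::Tensor weight = at::randn({8, 3, 3, 3}, at::kCUDA);
//   at::Tensor bias   = at::randn({8}, at::kCUDA);
//   at::Tensor out    = at::slow_conv_dilated2d(
//       input, weight, /*kernel_size=*/{3, 3}, bias,
//       /*stride=*/{1, 1}, /*padding=*/{1, 1}, /*dilation=*/{2, 2});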

std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    const std::array<bool, 3ul> output_mask) {
  Tensor undefined;
  internal::slow_conv_dilated_shape_check<2>(
      input,
      weight,
      undefined,
      grad_output,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 4;
  auto options = grad_output.options();
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor grad_output_ =
      (is_batch ? grad_output.contiguous()
                : grad_output.contiguous().unsqueeze(0));
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  // compute only gradients for which the corresponding output_mask is true:
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options) : undefined);
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options) : undefined);
  Tensor grad_bias =
      (output_mask[2] ? at::empty(weight.size(0), options) : undefined);
  Tensor grad_input_ =
      (output_mask[0] ? (is_batch ? grad_input : grad_input.unsqueeze(0))
                      : undefined);
  slow_conv_dilated_all_cuda_template<2>(
      undefined,
      input_,
      weight_,
      undefined,
      grad_output_,
      grad_input_,
      grad_weight,
      grad_bias,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return std::tie(grad_input, grad_weight, grad_bias);
}

Tensor slow_conv_dilated3d_cuda(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<3>(
      input,
      weight,
      bias,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 5;
  auto options = input.options();
  // calculate output tensor size
  auto output_size = internal::get_output_size<3>(
      input, weight, kernel_size, stride_size, pad_size, dilation_size);
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  const Tensor bias_ = (bias.defined() ? bias.contiguous() : undefined);
  Tensor output = at::empty(output_size, options);
  Tensor output_ = (is_batch ? output : output.unsqueeze(0));

  slow_conv_dilated_all_cuda_template<3>(
      output_,
      input_,
      weight_,
      bias_,
      undefined,
      undefined,
      undefined,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return output;
}

std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    const std::array<bool, 3ul> output_mask) {
  Tensor undefined;
  internal::slow_conv_dilated_shape_check<3>(
      input,
      weight,
      undefined,
      grad_output,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 5;
  auto options = grad_output.options();
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor grad_output_ =
      (is_batch ? grad_output.contiguous()
                : grad_output.contiguous().unsqueeze(0));
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  // compute only gradients for which the corresponding output_mask is true:
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options) : undefined);
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options) : undefined);
  Tensor grad_bias =
      (output_mask[2] ? at::empty(weight.size(0), options) : undefined);
  Tensor grad_input_ =
      (output_mask[0] ? (is_batch ? grad_input : grad_input.unsqueeze(0))
                      : undefined);
  slow_conv_dilated_all_cuda_template<3>(
      undefined,
      input_,
      weight_,
      undefined,
      grad_output_,
      grad_input_,
      grad_weight,
      grad_bias,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return std::tie(grad_input, grad_weight, grad_bias);
}

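// These registrations make the *_backward_cuda functions above the CUDA
// implementations behind the slow dilated convolution backward dispatch
// stubs (presumably declared in ATen/native/ConvUtils.h, included above);
// the forward functions are exposed as native ops through the generated
// slow_conv_dilated{2,3}d_native.h headers.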
REGISTER_CUDA_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cuda);
REGISTER_CUDA_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cuda);

} // namespace at::native