#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/div_rtn.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/im2col.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_slow_conv2d_forward_native.h>
#include <ATen/ops/_slow_conv2d_backward_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#endif

namespace at::native {
namespace {

void slow_conv2d_shape_check(
    const Tensor& input, const Tensor& grad_output,
    const Tensor& weight, const Tensor& bias,
    int64_t kH, int64_t kW,
    int64_t dH, int64_t dW,
    int64_t padH, int64_t padW,
    bool weight_nullable) {
  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got kH: ", kH, " kW: ", kW);
  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got dH: ", dH, " dW: ", dW);

  TORCH_CHECK(weight_nullable || weight.defined(),
              "weight tensor is expected to be non-nullable");
  TORCH_CHECK(!weight.defined() ||
              ((weight.numel() > 0) && (weight.dim() == 2)),
              "non-empty 2D weight tensor expected, but got: ", weight.sizes());
  TORCH_CHECK(!bias.defined() || (bias.dim() == 1 && bias.sizes()[0] == weight.sizes()[0]),
              "Expected bias to have shape [", weight.sizes()[0], "] but got ", bias.sizes());

  const auto in_sizes = input.sizes();
  constexpr int ndim = 4;
  constexpr int dimf = 1;
  constexpr int dimh = 2;
  constexpr int dimw = 3;
  TORCH_CHECK(in_sizes.size() == ndim, "Expected 4D input tensor, but got ", in_sizes);

  // Allow for empty batch size but not other dimensions
  const bool valid_empty = c10::multiply_integers(in_sizes.slice(1)) != 0;
  TORCH_CHECK(valid_empty, "non-empty input tensor expected but got: ", in_sizes);

  int64_t inputHeight = in_sizes[dimh];
  int64_t inputWidth = in_sizes[dimw];

  int64_t exactInputHeight = inputHeight + 2 * padH;
  int64_t exactInputWidth = inputWidth + 2 * padW;

  TORCH_CHECK(exactInputHeight >= kH && exactInputWidth >= kW,
              "Calculated padded input size per channel: ",
              IntArrayRef{exactInputHeight, exactInputWidth},
              ". Kernel size: ", IntArrayRef{kH, kW},
              ". Kernel size can't be greater than actual input size");

  // NOTE: can't use conv_output_size if the weight isn't defined
  auto outputHeight = div_rtn<int64_t>(exactInputHeight - kH, dH) + 1;
  auto outputWidth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;
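  // div_rtn rounds the quotient towards negative infinity; since the check
  // above guarantees the padded input is at least as large as the kernel,
  // this is the usual floor((in + 2*pad - k) / stride) + 1 output-size formula.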

  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
              "Given input size per channel: ",
              IntArrayRef{inputHeight, inputWidth},
              ". Calculated output size per channel: ",
              IntArrayRef{outputHeight, outputWidth},
              ". Output size is too small");

  if (weight.defined()) {
    const auto w_sizes = weight.sizes();
    int64_t nInputPlane = w_sizes[1];
    if (w_sizes.size() == 2) {
      nInputPlane /= (kH * kW);
    }
    TORCH_CHECK(in_sizes[dimf] == nInputPlane,
                "Expected input dim ", dimf, " to have size ", nInputPlane,
                " but got ", in_sizes[dimf]);
  }

  if (grad_output.defined()) {
    const auto gO_sizes = grad_output.sizes();
    TORCH_CHECK(gO_sizes.size() == ndim,
                "Expected grad_output to have ", ndim,
                " dimensions but got shape ", gO_sizes);

    if (weight.defined()) {
      const auto w_sizes = weight.sizes();
      TORCH_CHECK(gO_sizes[dimf] == w_sizes[0],
                  "Expected grad_output dim ", dimf, " to have size ", w_sizes[0],
                  " but got ", gO_sizes[dimf]);
    } else if (bias.defined()) {
      const auto b_sizes = bias.sizes();
      int64_t nOutputPlane = b_sizes.size() == 0 ? 1 : b_sizes[0];
      TORCH_CHECK(gO_sizes[dimf] == nOutputPlane,
                  "Expected grad_output dim ", dimf, " to have size ",
                  nOutputPlane, " but got ", gO_sizes[dimf]);
    }
    TORCH_CHECK(gO_sizes[dimh] == outputHeight,
                "Expected grad_output dim ", dimh, " to have size ",
                outputHeight, " but got ", gO_sizes[dimh]);
    TORCH_CHECK(gO_sizes[dimw] == outputWidth,
                "Expected grad_output dim ", dimw, " to have size ",
                outputWidth, " but got ", gO_sizes[dimw]);
  }
}

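// Flattens a contiguous 4D weight of shape [nOutputPlane, nInputPlane, kH, kW]
// into the 2D [nOutputPlane, nInputPlane * kH * kW] view used by the GEMM calls below.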
Tensor new_view_weight_MM2d(const Tensor& weight_) {
  auto weight = weight_.expect_contiguous();
  const auto w_sizes = weight->sizes();
  TORCH_CHECK(w_sizes.size() == 4);
  int64_t s1 = w_sizes[0];
  int64_t s2 = c10::multiply_integers(w_sizes.slice(1));
  return weight->view({s1, s2});
}

void slow_conv2d_forward(
           const Tensor &input,
           const Tensor &output,
           const Tensor &weight_,
           const Tensor &bias,
           int64_t kH, int64_t kW,
           int64_t dH, int64_t dW,
           int64_t padH, int64_t padW) {
  auto weight = new_view_weight_MM2d(weight_);
  slow_conv2d_shape_check(
      input, {}, weight, bias, kH, kW, dH, dW, padH, padW, /*weight_nullable*/false);

  constexpr int dimf = 1;
  constexpr int dimh = 2;
  constexpr int dimw = 3;

  auto in_sizes = input.sizes();
  int64_t batchSize = in_sizes[0];
  int64_t nInputPlane  = in_sizes[dimf];
  int64_t inputHeight  = in_sizes[dimh];
  int64_t inputWidth   = in_sizes[dimw];
  int64_t nOutputPlane = weight.sizes()[0];
  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;

  // Resize output
  resize_output(output, {batchSize, nOutputPlane, outputHeight, outputWidth});

  // Create temporary columns
  at::Tensor columns;

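  // A 1x1 kernel with unit stride and no padding is already "unrolled": each
  // output pixel reads exactly one input pixel per channel, so the input can
  // feed the GEMM directly and im2col can be skipped.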
  const bool requires_columns = (
      kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0);

  if (requires_columns) {
    columns = at::empty({nInputPlane * kW * kH, outputHeight * outputWidth}, input.options());
  }

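  // Pre-fill the output with the broadcast bias (or zeros); the GEMM below
  // uses beta = 1, so the matrix product is accumulated on top of this.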
  if (bias.defined()) {
    TORCH_CHECK(bias.scalar_type() == input.scalar_type(),
                "Expected bias to have type ", input.scalar_type(),
                " but got ", bias.scalar_type());
    output.copy_(bias.view({-1, 1, 1}));
  } else {
    output.zero_();
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                  "slow_conv2d_cuda", [&] {
    // For each elt in batch, do:
    for (int elt = 0; elt < batchSize; elt ++) {
      // Matrix multiply per output:
      auto input_n = input.select(0, elt);
      auto output_n = output.select(0, elt);

      if (requires_columns) {
        // Extract columns:
        at::native::im2col(
          c10::cuda::getCurrentCUDAStream(),
          input_n.const_data_ptr<scalar_t>(),
          nInputPlane, inputHeight, inputWidth,
          outputHeight, outputWidth,
          kH, kW, padH, padW, dH, dW,
          1, 1,
          columns.mutable_data_ptr<scalar_t>()
        );
      }

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = nOutputPlane;
      int64_t n = outputHeight * outputWidth;
      int64_t k = nInputPlane*kH*kW;

      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
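      // cuBLAS is column-major; reinterpreting the row-major buffers as
      // column-major gives their transposes, so this call computes
      //   output_n[m, n] = weight[m, k] @ columns[k, n]
      // (accumulated on top of the bias because beta = 1).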
      auto gemm_in_ptr = requires_columns ?
          columns.const_data_ptr<scalar_t>() :
          input_n.const_data_ptr<scalar_t>();
      at::cuda::blas::gemm(
          'n', 'n',
          n, m, k,
          scalar_t(1),
          gemm_in_ptr, n,
          weight.const_data_ptr<scalar_t>(), k,
          scalar_t(1),
          output_n.mutable_data_ptr<scalar_t>(), n
      );
    }
  });
}

void slow_conv2d_backward(
    const Tensor &input,
    const Tensor &grad_output,
    const Tensor &grad_input,
    const Tensor &weight_,
    const Tensor &grad_columns,
    int kH, int kW,
    int dH, int dW,
    int padH, int padW) {
  Tensor weight = new_view_weight_MM2d(weight_);
  slow_conv2d_shape_check(input, grad_output, weight, {},
                          kH, kW, dH, dW, padH, padW, /*weight_nullable=*/false);

  // Params
  auto weight_sizes = weight.sizes();
  int nInputPlane = weight_sizes[1]/(kW*kH);
  int nOutputPlane = weight_sizes[0];

  TORCH_INTERNAL_ASSERT(grad_output.is_contiguous());

  auto input_sizes = input.sizes();
  int64_t inputWidth   = input_sizes[3];
  int64_t inputHeight  = input_sizes[2];
  auto output_sizes = grad_output.sizes();
  int64_t outputWidth  = output_sizes[3];
  int64_t outputHeight = output_sizes[2];

  // Batch size + input planes
  int64_t batchSize = input_sizes[0];

  // Resize output
  resize_output(grad_input, input_sizes);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

  // Resize temporary columns
  resize_output(grad_columns, {nInputPlane*kW*kH, outputHeight*outputWidth});
  TORCH_CHECK(grad_columns.is_contiguous(), "grad_columns must be contiguous");

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                  "slow_conv2d_backward_cuda", [&] {
    // For each elt in batch, do:
    for (int elt = 0; elt < batchSize; elt ++) {
      // Matrix multiply per sample:
      auto grad_input_n = grad_input.select(0, elt);
      auto grad_output_n = grad_output.select(0, elt);

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = nInputPlane*kW*kH;
      int64_t n = grad_columns.sizes()[1];
      int64_t k = nOutputPlane;

      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
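      // Row-major: grad_columns[m, n] = weight^T[m, k] @ grad_output_n[k, n],
      // expressed below by passing grad_output_n untransposed ('n') and the
      // weight transposed ('t') to the column-major gemm.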
      at::cuda::blas::gemm<scalar_t>(
          'n', 't',
          n, m, k,
          scalar_t(1),
          grad_output_n.const_data_ptr<scalar_t>(), n,
          weight.const_data_ptr<scalar_t>(), m,
          scalar_t(0),
          grad_columns.mutable_data_ptr<scalar_t>(), n
      );

      // Unpack columns back into input:
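      // col2im sums the contributions of overlapping windows, which accumulates
      // the gradient for input pixels covered by multiple receptive fields.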
      using acc_t = at::acc_type<scalar_t, true>;
      at::native::col2im<scalar_t, acc_t>(
        c10::cuda::getCurrentCUDAStream(),
        grad_columns.const_data_ptr<scalar_t>(),
        nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
        1, 1, grad_input_n.mutable_data_ptr<scalar_t>()
      );
    }
  });
}

void slow_conv2d_grad_weight(
           const Tensor &input,
           const Tensor &grad_output,
           const Tensor &grad_weight_,
           const Tensor &columns,
           int64_t kH, int64_t kW,
           int64_t dH, int64_t dW,
           int64_t padH, int64_t padW) {
  TORCH_CHECK(grad_weight_.is_contiguous(), "grad_weight needs to be contiguous");
  auto grad_weight = new_view_weight_MM2d(grad_weight_);
  slow_conv2d_shape_check(input, grad_output, grad_weight, {},
                          kH, kW, dH, dW, padH, padW, /*weight_nullable=*/true);

  // Params
  TORCH_INTERNAL_ASSERT(input.is_contiguous());
  TORCH_INTERNAL_ASSERT(grad_output.is_contiguous());

  auto input_sizes = input.sizes();
  int64_t nInputPlane = input_sizes[1];
  int64_t nOutputPlane = grad_output.sizes()[1];

  int64_t inputWidth   = input_sizes[3];
  int64_t inputHeight  = input_sizes[2];
  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;

  // Batch size + input planes
  int64_t batchSize = input_sizes[0];

  // Resize temporary columns
  resize_output(columns, {nInputPlane * kH * kW, outputHeight * outputWidth});

  const bool requires_columns = (
      kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0);

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                  "slow_conv2d_grad_weight_cuda", [&] {
    // For each elt in batch, do:
    for (int elt = 0; elt < batchSize; elt ++) {
      // Matrix multiply per output:
      auto grad_output_n = grad_output.select(0, elt);
      auto input_n = input.select(0, elt);

      if (requires_columns) {
        // Extract columns:
        at::native::im2col<scalar_t>(
          c10::cuda::getCurrentCUDAStream(),
          input_n.const_data_ptr<scalar_t>(),
          nInputPlane, inputHeight, inputWidth,
          outputHeight, outputWidth,
          kH, kW, padH, padW, dH, dW,
          1, 1,
          columns.mutable_data_ptr<scalar_t>()
        );
      }

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = nOutputPlane;
      int64_t n = nInputPlane*kW*kH;
      int64_t k = columns.sizes()[1];

      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
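      // Row-major: grad_weight[m, n] += grad_output_n[m, k] @ columns^T[k, n];
      // beta = 1 accumulates each batch element's contribution into
      // grad_weight, which the caller has zero-initialized.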
      auto gemm_in_ptr = requires_columns ?
          columns.const_data_ptr<scalar_t>() :
          input_n.const_data_ptr<scalar_t>();
      at::cuda::blas::gemm(
          't', 'n',
          n, m, k,
          scalar_t(1),
          gemm_in_ptr, k,
          grad_output_n.const_data_ptr<scalar_t>(), k,
          scalar_t(1),
          grad_weight.mutable_data_ptr<scalar_t>(), n
      );
    }
  });
}

}  // namespace (anonymous)


Tensor& slow_conv2d_forward_out_cuda(
    const Tensor &self_,
    const Tensor &weight_,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias_,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor &output) {
  TORCH_CHECK(kernel_size.size() == 2);
  TORCH_CHECK(stride.size() == 2);
  TORCH_CHECK(padding.size() == 2);

  auto self = self_.expect_contiguous();
  auto weight = weight_.expect_contiguous();
  auto bias = [&] {
    if (bias_.has_value() && bias_->defined()) {
      return bias_->expect_contiguous();
    }
    return MaybeOwned<Tensor>::owned(std::in_place);
  }();
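  // If no bias was supplied, the lambda above yields a default-constructed
  // (undefined) Tensor, which slow_conv2d_forward treats as "no bias".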

  slow_conv2d_forward(
      *self,
      output,
      *weight,
      *bias,
      kernel_size[0], kernel_size[1],
      stride[0], stride[1],
      padding[0], padding[1]
    );
  return output;
}

Tensor slow_conv2d_forward_cuda(
    const Tensor &self,
    const Tensor &weight,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto output = at::empty({0}, self.options());
  return slow_conv2d_forward_out_cuda(
      self, weight, kernel_size, bias, stride, padding, output);
}

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_backward_out_cuda(
    const Tensor& grad_output_,
    const Tensor& self_,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  auto grad_output = grad_output_.expect_contiguous();

  Tensor columns = at::empty({0}, self_.options());
  if (grad_input.defined()) {
    resize_output(grad_input, self_.sizes());
    auto weight = weight_.expect_contiguous();

    slow_conv2d_backward(
        self_, *grad_output,
        grad_input, *weight,
        columns,
        kernel_size[0], kernel_size[1],
        stride[0], stride[1],
        padding[0], padding[1]);
  }
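  // grad_bias is simply grad_output summed over the batch and spatial dimensions.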
  if (grad_bias.defined()) {
    at::sum_out(grad_bias, *grad_output, IntArrayRef{0, 2, 3});
  }
  if (grad_weight.defined()) {
    resize_output(grad_weight, weight_.sizes());
    grad_weight.zero_();
    auto self = self_.expect_contiguous();
    slow_conv2d_grad_weight(
        *self,
        *grad_output,
        grad_weight,
        columns,
        kernel_size[0], kernel_size[1],
        stride[0], stride[1],
        padding[0], padding[1]
      );
  }
  return std::tuple<Tensor&, Tensor&, Tensor&>{
      grad_input, grad_weight, grad_bias};
}

std::tuple<Tensor, Tensor, Tensor> slow_conv2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  }

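  // Gradients not requested via output_mask stay undefined and are skipped by
  // slow_conv2d_backward_out_cuda.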
  return native::slow_conv2d_backward_out_cuda(
      grad_output,
      self,
      weight,
      kernel_size,
      stride,
      padding,
      grad_input,
      grad_weight,
      grad_bias);
}

} // namespace at::native