#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/div_rtn.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/im2col.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_slow_conv2d_forward_native.h>
#include <ATen/ops/_slow_conv2d_backward_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#endif

namespace at::native {
namespace {

void slow_conv2d_shape_check(
    const Tensor& input, const Tensor& grad_output,
    const Tensor& weight, const Tensor& bias,
    int64_t kH, int64_t kW,
    int64_t dH, int64_t dW,
    int64_t padH, int64_t padW,
    bool weight_nullable) {
  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got kH: ", kH, " kW: ", kW);
  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got dH: ", dH, " dW: ", dW);

  TORCH_CHECK(weight_nullable || weight.defined(),
              "weight tensor is expected to be non-nullable");
  TORCH_CHECK(!weight.defined() ||
              ((weight.numel() > 0) && (weight.dim() == 2)),
              "non-empty 2D weight tensor expected, but got: ", weight.sizes());
  TORCH_CHECK(!bias.defined() || (bias.dim() == 1 && bias.sizes()[0] == weight.sizes()[0]),
              "Expected bias to have shape [", weight.sizes()[0], "] but got ", bias.sizes());

  const auto in_sizes = input.sizes();
  constexpr int ndim = 4;
  constexpr int dimf = 1;
  constexpr int dimh = 2;
  constexpr int dimw = 3;
  TORCH_CHECK(in_sizes.size() == ndim, "Expected 4D input tensor, but got ", in_sizes);

  // Allow for empty batch size but not other dimensions
  const bool valid_empty = c10::multiply_integers(in_sizes.slice(1)) != 0;
  TORCH_CHECK(valid_empty, "non-empty input tensor expected but got: ", in_sizes);

  int64_t inputHeight = in_sizes[dimh];
  int64_t inputWidth = in_sizes[dimw];

  int64_t exactInputHeight = inputHeight + 2 * padH;
  int64_t exactInputWidth = inputWidth + 2 * padW;

  TORCH_CHECK(exactInputHeight >= kH && exactInputWidth >= kW,
              "Calculated padded input size per channel: ",
              IntArrayRef{exactInputHeight, exactInputWidth},
              ". Kernel size: ", IntArrayRef{kH, kW},
              ". Kernel size can't be greater than actual input size");

  // NOTE: can't use conv_output_size if the weight isn't defined
  auto outputHeight = div_rtn<int64_t>(exactInputHeight - kH, dH) + 1;
  auto outputWidth = div_rtn<int64_t>(exactInputWidth - kW, dW) + 1;

  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
              "Given input size per channel: ",
              IntArrayRef{inputHeight, inputWidth},
              ". Calculated output size per channel: ",
              IntArrayRef{outputHeight, outputWidth},
              ". Output size is too small");

  if (weight.defined()) {
    const auto w_sizes = weight.sizes();
    int64_t nInputPlane = w_sizes[1];
    if (w_sizes.size() == 2) {
      nInputPlane /= (kH * kW);
    }
    TORCH_CHECK(in_sizes[dimf] == nInputPlane,
                "Expected input dim ", dimf, " to have size ", nInputPlane,
                " but got ", in_sizes[dimf]);
  }

  if (grad_output.defined()) {
    const auto gO_sizes = grad_output.sizes();
    TORCH_CHECK(gO_sizes.size() == ndim,
                "Expected grad_output to have ", ndim,
                " dimensions but got shape ", gO_sizes);

    if (weight.defined()) {
      const auto w_sizes = weight.sizes();
      TORCH_CHECK(gO_sizes[dimf] == w_sizes[0],
                  "Expected dim ", dimf, " to have size ", w_sizes[0],
                  " but got ", gO_sizes[dimf]);
    } else if (bias.defined()) {
      const auto b_sizes = bias.sizes();
      int64_t nOutputPlane = b_sizes.size() == 0 ? 1 : b_sizes[0];
      TORCH_CHECK(gO_sizes[dimf] == nOutputPlane,
                  "Expected grad_output dim ", dimf, " to have size ",
                  nOutputPlane, " but got ", gO_sizes[dimf]);
    }
    TORCH_CHECK(gO_sizes[dimh] == outputHeight,
                "Expected grad_output dim ", dimh, " to have size ",
                outputHeight, " but got ", gO_sizes[dimh]);
    TORCH_CHECK(gO_sizes[dimw] == outputWidth,
                "Expected grad_output dim ", dimw, " to have size ",
                outputWidth, " but got ", gO_sizes[dimw]);
  }
}

Tensor new_view_weight_MM2d(const Tensor& weight_) {
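  // Flatten the contiguous 4-D weight (nOutputPlane, nInputPlane, kH, kW) into
  // a 2-D (nOutputPlane, nInputPlane*kH*kW) matrix so it can be used directly
  // as a GEMM operand.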
  auto weight = weight_.expect_contiguous();
  const auto w_sizes = weight->sizes();
  TORCH_CHECK(w_sizes.size() == 4);
  int64_t s1 = w_sizes[0];
  int64_t s2 = c10::multiply_integers(w_sizes.slice(1));
  return weight->view({s1, s2});
}

void slow_conv2d_forward(
    const Tensor &input,
    const Tensor &output,
    const Tensor &weight_,
    const Tensor &bias,
    int64_t kH, int64_t kW,
    int64_t dH, int64_t dW,
    int64_t padH, int64_t padW) {
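  // im2col + GEMM formulation: each image in the batch is unrolled into a
  // (nInputPlane*kH*kW) x (outputHeight*outputWidth) "columns" matrix, and the
  // convolution reduces to one matrix multiply with the flattened weight.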
  auto weight = new_view_weight_MM2d(weight_);
  slow_conv2d_shape_check(
      input, {}, weight, bias, kH, kW, dH, dW, padH, padW, /*weight_nullable*/false);

  constexpr int dimf = 1;
  constexpr int dimh = 2;
  constexpr int dimw = 3;

  auto in_sizes = input.sizes();
  int64_t batchSize = in_sizes[0];
  int64_t nInputPlane = in_sizes[dimf];
  int64_t inputHeight = in_sizes[dimh];
  int64_t inputWidth = in_sizes[dimw];
  int64_t nOutputPlane = weight.sizes()[0];
  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
  int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
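  // The output sizes above follow the standard convolution formula
  // floor((in + 2*pad - kernel) / stride) + 1; dilation is fixed at 1 here.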

  // Resize output
  resize_output(output, {batchSize, nOutputPlane, outputHeight, outputWidth});

  // Create temporary columns
  at::Tensor columns;

  const bool requires_columns = (
      kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0);
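  // A 1x1 kernel with unit stride and no padding makes im2col an identity
  // copy, so in that case the GEMM reads the input directly.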

  if (requires_columns) {
    columns = at::empty({nInputPlane * kW * kH, outputHeight * outputWidth}, input.options());
  }

  if (bias.defined()) {
    TORCH_CHECK(bias.scalar_type() == input.scalar_type(),
                "Expected bias to have type ", input.scalar_type(),
                " but got ", bias.scalar_type());
    output.copy_(bias.view({-1, 1, 1}));
  } else {
    output.zero_();
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
      "slow_conv2d_cuda", [&] {
    // For each elt in batch, do:
    for (int elt = 0; elt < batchSize; elt++) {
      // Matrix multiply per output:
      auto input_n = input.select(0, elt);
      auto output_n = output.select(0, elt);

      if (requires_columns) {
        // Extract columns:
        at::native::im2col(
            c10::cuda::getCurrentCUDAStream(),
            input_n.const_data_ptr<scalar_t>(),
            nInputPlane, inputHeight, inputWidth,
            outputHeight, outputWidth,
            kH, kW, padH, padW, dH, dW,
            1, 1,
            columns.mutable_data_ptr<scalar_t>()
        );
      }

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = nOutputPlane;
      int64_t n = outputHeight * outputWidth;
      int64_t k = nInputPlane*kH*kW;

      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
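      // Row-major view: output_n (m x n) += weight (m x k) @ columns (k x n).
      // cuBLAS is column-major, which is why the dimensions are passed in the
      // swapped (n, m, k) order below; beta == 1 accumulates on top of the
      // bias (or zeros) already written into output.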
      auto gemm_in_ptr = requires_columns ?
          columns.const_data_ptr<scalar_t>() :
          input_n.const_data_ptr<scalar_t>();
      at::cuda::blas::gemm(
          'n', 'n',
          n, m, k,
          scalar_t(1),
          gemm_in_ptr, n,
          weight.const_data_ptr<scalar_t>(), k,
          scalar_t(1),
          output_n.mutable_data_ptr<scalar_t>(), n
      );
    }
  });
}

void slow_conv2d_backward(
    const Tensor &input,
    const Tensor &grad_output,
    const Tensor &grad_input,
    const Tensor &weight_,
    const Tensor &grad_columns,
    int kH, int kW,
    int dH, int dW,
    int padH, int padW) {
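  // grad_input is computed as the transpose of the forward pass: a GEMM with
  // the transposed weight produces gradient "columns", which col2im then folds
  // back into the input layout.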
  Tensor weight = new_view_weight_MM2d(weight_);
  slow_conv2d_shape_check(input, grad_output, weight, {},
                          kH, kW, dH, dW, padH, padW, /*weight_nullable=*/false);

  // Params
  auto weight_sizes = weight.sizes();
  int nInputPlane = weight_sizes[1]/(kW*kH);
  int nOutputPlane = weight_sizes[0];

  TORCH_INTERNAL_ASSERT(grad_output.is_contiguous());

  auto input_sizes = input.sizes();
  int64_t inputWidth = input_sizes[3];
  int64_t inputHeight = input_sizes[2];
  auto output_sizes = grad_output.sizes();
  int64_t outputWidth = output_sizes[3];
  int64_t outputHeight = output_sizes[2];

  // Batch size + input planes
  int64_t batchSize = input_sizes[0];

  // Resize output
  resize_output(grad_input, input_sizes);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

  // Resize temporary columns
  resize_output(grad_columns, {nInputPlane*kW*kH, outputHeight*outputWidth});
  TORCH_CHECK(grad_columns.is_contiguous(), "grad_columns must be contiguous");

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
      "slow_conv2d_backward_cuda", [&] {
    // For each elt in batch, do:
    for (int elt = 0; elt < batchSize; elt++) {
      // Matrix multiply per sample:
      auto grad_input_n = grad_input.select(0, elt);
      auto grad_output_n = grad_output.select(0, elt);

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = nInputPlane*kW*kH;
      int64_t n = grad_columns.sizes()[1];
      int64_t k = nOutputPlane;

      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
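      // Row-major view: grad_columns (m x n) = weight^T (m x k) @ grad_output_n (k x n).
      // The 'n','t' ops express this for column-major cuBLAS; beta == 0 because
      // grad_columns is overwritten on every iteration.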
      at::cuda::blas::gemm<scalar_t>(
          'n', 't',
          n, m, k,
          scalar_t(1),
          grad_output_n.const_data_ptr<scalar_t>(), n,
          weight.const_data_ptr<scalar_t>(), m,
          scalar_t(0),
          grad_columns.mutable_data_ptr<scalar_t>(), n
      );

      // Unpack columns back into input:
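      // col2im sums the contributions of overlapping windows back into
      // grad_input; accumulation is done in acc_t (float for half/bfloat16)
      // to limit rounding error.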
      using acc_t = at::acc_type<scalar_t, true>;
      at::native::col2im<scalar_t, acc_t>(
          c10::cuda::getCurrentCUDAStream(),
          grad_columns.const_data_ptr<scalar_t>(),
          nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
          1, 1, grad_input_n.mutable_data_ptr<scalar_t>()
      );
    }
  });
}

void slow_conv2d_grad_weight(
    const Tensor &input,
    const Tensor &grad_output,
    const Tensor &grad_weight_,
    const Tensor &columns,
    int64_t kH, int64_t kW,
    int64_t dH, int64_t dW,
    int64_t padH, int64_t padW) {
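  // grad_weight accumulates, over the batch, the product of grad_output with
  // the im2col expansion of the input (the outer-product form of the
  // convolution weight gradient).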
  TORCH_CHECK(grad_weight_.is_contiguous(), "grad_weight needs to be contiguous");
  auto grad_weight = new_view_weight_MM2d(grad_weight_);
  slow_conv2d_shape_check(input, grad_output, grad_weight, {},
                          kH, kW, dH, dW, padH, padW, /*weight_nullable=*/true);

  // Params
  TORCH_INTERNAL_ASSERT(input.is_contiguous());
  TORCH_INTERNAL_ASSERT(grad_output.is_contiguous());

  auto input_sizes = input.sizes();
  int64_t nInputPlane = input_sizes[1];
  int64_t nOutputPlane = grad_output.sizes()[1];

  int64_t inputWidth = input_sizes[3];
  int64_t inputHeight = input_sizes[2];
  int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;

  // Batch size + input planes
  int64_t batchSize = input_sizes[0];

  // Resize temporary columns
  resize_output(columns, {nInputPlane * kH * kW, outputHeight * outputWidth});

  const bool requires_columns = (
      kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0);
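  // As in the forward pass, a 1x1 / stride-1 / unpadded convolution can read
  // the input directly instead of materializing columns.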

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
      "slow_conv2d_grad_weight_cuda", [&] {
    // For each elt in batch, do:
    for (int elt = 0; elt < batchSize; elt++) {
      // Matrix multiply per output:
      auto grad_output_n = grad_output.select(0, elt);

      // Matrix multiply per output:
      auto input_n = input.select(0, elt);

      if (requires_columns) {
        // Extract columns:
        at::native::im2col<scalar_t>(
            c10::cuda::getCurrentCUDAStream(),
            input_n.const_data_ptr<scalar_t>(),
            nInputPlane, inputHeight, inputWidth,
            outputHeight, outputWidth,
            kH, kW, padH, padW, dH, dW,
            1, 1,
            columns.mutable_data_ptr<scalar_t>()
        );
      }

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = nOutputPlane;
      int64_t n = nInputPlane*kW*kH;
      int64_t k = columns.sizes()[1];

      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
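      // Row-major view: grad_weight (m x n) += grad_output_n (m x k) @ columns^T (k x n);
      // beta == 1 accumulates the contribution of each batch element.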
      auto gemm_in_ptr = requires_columns ?
          columns.const_data_ptr<scalar_t>() :
          input_n.const_data_ptr<scalar_t>();
      at::cuda::blas::gemm(
          't', 'n',
          n, m, k,
          scalar_t(1),
          gemm_in_ptr, k,
          grad_output_n.const_data_ptr<scalar_t>(), k,
          scalar_t(1),
          grad_weight.mutable_data_ptr<scalar_t>(), n
      );
    }
  });
}

} // namespace (anonymous)


Tensor& slow_conv2d_forward_out_cuda(
    const Tensor &self_,
    const Tensor &weight_,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias_,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor &output) {
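  // Public entry point: validate the geometry arguments, obtain contiguous
  // (possibly copied) views of the inputs, and dispatch to the helper above.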
  TORCH_CHECK(kernel_size.size() == 2);
  TORCH_CHECK(stride.size() == 2);
  TORCH_CHECK(padding.size() == 2);

  auto self = self_.expect_contiguous();
  auto weight = weight_.expect_contiguous();
  auto bias = [&] {
    if (bias_.has_value() && bias_->defined()) {
      return bias_->expect_contiguous();
    }
    return MaybeOwned<Tensor>::owned(std::in_place);
  }();

  slow_conv2d_forward(
      *self,
      output,
      *weight,
      *bias,
      kernel_size[0], kernel_size[1],
      stride[0], stride[1],
      padding[0], padding[1]
  );
  return output;
}

Tensor slow_conv2d_forward_cuda(
    const Tensor &self,
    const Tensor &weight,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias,
    IntArrayRef stride,
    IntArrayRef padding) {
  auto output = at::empty({0}, self.options());
  return slow_conv2d_forward_out_cuda(
      self, weight, kernel_size, bias, stride, padding, output);
}

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_backward_out_cuda(
    const Tensor& grad_output_,
    const Tensor& self_,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  auto grad_output = grad_output_.expect_contiguous();

  Tensor columns = at::empty({0}, self_.options());
  if (grad_input.defined()) {
    resize_output(grad_input, self_.sizes());
    auto weight = weight_.expect_contiguous();

    slow_conv2d_backward(
        self_, *grad_output,
        grad_input, *weight,
        columns,
        kernel_size[0], kernel_size[1],
        stride[0], stride[1],
        padding[0], padding[1]);
  }
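  // grad_bias is simply grad_output reduced over the batch and spatial
  // dimensions.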
  if (grad_bias.defined()) {
    at::sum_out(grad_bias, *grad_output, IntArrayRef{0, 2, 3});
  }
  if (grad_weight.defined()) {
    resize_output(grad_weight, weight_.sizes());
    grad_weight.zero_();
    auto self = self_.expect_contiguous();
    slow_conv2d_grad_weight(
        *self,
        *grad_output,
        grad_weight,
        columns,
        kernel_size[0], kernel_size[1],
        stride[0], stride[1],
        padding[0], padding[1]
    );
  }
  return std::tuple<Tensor&, Tensor&, Tensor&>{
      grad_input, grad_weight, grad_bias};
}

std::tuple<Tensor, Tensor, Tensor> slow_conv2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    std::array<bool, 3> output_mask) {
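  // output_mask selects which gradients to materialize; gradients that stay
  // undefined are skipped by slow_conv2d_backward_out_cuda.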
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  }

  return native::slow_conv2d_backward_out_cuda(
      grad_output,
      self,
      weight,
      kernel_size,
      stride,
      padding,
      grad_input,
      grad_weight,
      grad_bias);
}

} // namespace at::native