#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/AccumulateType.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/ConvUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/conv_depthwise3d_native.h>
#endif

#include <algorithm>
#include <tuple>
#include <limits>

namespace at::native {
namespace {

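// Forward kernel: one thread per output element. The flat index is decoded as
// (batch, out_channel, out_frame, out_row, out_col), and each thread
// accumulates the kT*kH*kW products of its depthwise filter with the
// corresponding input window in accscalar_t precision, adding the optional
// bias at the end. The kKnownKernel* / kKnownDilation* template parameters let
// common configurations (e.g. 3x3x3 kernels with dilation 1) be compiled with
// constant loop bounds; a value of -1 means "not known at compile time" and
// the runtime argument is used instead.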
template <typename scalar_t, typename accscalar_t,
          int kKnownKernelT, int kKnownKernelH, int kKnownKernelW,
          int kKnownDilationT, int kKnownDilationH, int kKnownDilationW>
__global__ void conv_depthwise3d_cuda_kernel(
    const PackedTensorAccessor32<const scalar_t, 5> input,
    PackedTensorAccessor32<scalar_t, 5> output,
    const PackedTensorAccessor32<const scalar_t, 5> kernel,
    const scalar_t* bias,
    int strideT, int strideH, int strideW,
    int paddingT, int paddingH, int paddingW,
    int dilationT_, int dilationH_, int dilationW_)
{
  const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2);
  const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3);
  const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4);
  const int oC = output.size(1);
  const int oT = output.size(2);
  const int oH = output.size(3);
  const int oW = output.size(4);
  const int iC = input.size(1);
  const int iT = input.size(2);
  const int iH = input.size(3);
  const int iW = input.size(4);
  const int channel_multiplier = oC / iC;
  const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_;
  const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_;
  const int dilationW = kKnownDilationW > 0 ? kKnownDilationW : dilationW_;
  const int num_output = output.size(0) * output.stride(0);

  CUDA_KERNEL_LOOP(index, num_output) {
    const int out_col = index % oW;
    const int out_row = (index / oW) % oH;
    const int out_frame = (index / oW / oH) % oT;
    const int out_channel = (index / oW / oH / oT) % oC;
    const int batch = index / oW / oH / oT / oC;

    const int in_channel = out_channel / channel_multiplier;

    const int in_col_start = out_col * strideW - paddingW;
    const int in_row_start = out_row * strideH - paddingH;
    const int in_frame_start = out_frame * strideT - paddingT;

    accscalar_t sum = 0;
    const scalar_t *kernel_ptr = kernel[out_channel].data();
    const scalar_t *input_ptr =
      &input[batch][in_channel][in_frame_start][in_row_start][in_col_start];
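    // Walk input_ptr in lockstep with the filter taps instead of recomputing
    // offsets: advance by dilationW per column tap, then step to the next
    // dilated row (iW * dilationH) minus the kW * dilationW already covered,
    // and similarly for frames. Out-of-bounds taps contribute nothing
    // (implicit zero padding), but the pointer is still advanced so it keeps
    // tracking the current tap position.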
    for (int k_frame = 0; k_frame < kT; ++k_frame) {
      const int in_frame = in_frame_start + k_frame * dilationT;
      for (int k_row = 0; k_row < kH; ++k_row) {
        const int in_row = in_row_start + k_row * dilationH;
        for (int k_col = 0; k_col < kW; ++k_col) {
          const accscalar_t op1 = *(kernel_ptr++);
          const int in_col = in_col_start + k_col * dilationW;
          if (in_frame >= 0 && in_row >= 0 && in_col >= 0 &&
              in_frame < iT && in_row < iH && in_col < iW) {
            sum += op1 * *(input_ptr);
          }
          input_ptr += dilationW;
        }
        input_ptr += iW * dilationH - kW * dilationW;
      }
      input_ptr += iW * (iH * dilationT - kH * dilationH);
    }
    if (bias != NULL) {
      sum += bias[out_channel];
    }

    output[batch][out_channel][out_frame][out_row][out_col] = sum;
  }
}

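// Backward w.r.t. input: one thread per grad_input element. An output
// position contributes to this input position when
//   out * stride = in + padding - k * dilation
// for some tap k, so the kernel iterates over all taps (and over the
// channel_multiplier output channels mapped to this input channel), divides
// by the stride, and only accumulates when the division is exact and the
// result lies inside grad_output. Stride is a template parameter here as
// well, presumably so that small fixed strides can be resolved at compile
// time.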
template <typename scalar_t, typename accscalar_t,
          int kKnownKernelT, int kKnownKernelH, int kKnownKernelW,
          int kKnownDilationT, int kKnownDilationH, int kKnownDilationW,
          int kKnownStrideT, int kKnownStrideH, int kKnownStrideW>
__global__ void
conv_depthwise3d_cuda_backward_input_kernel(
    const PackedTensorAccessor32<const scalar_t, 5> grad_output,
    PackedTensorAccessor32<scalar_t, 5> grad_input,
    const PackedTensorAccessor32<const scalar_t, 5> kernel,
    int strideT_, int strideH_, int strideW_,
    int paddingT, int paddingH, int paddingW,
    int dilationT_, int dilationH_, int dilationW_) {
  const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2);
  const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3);
  const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4);
  const int oC = grad_output.size(1);
  const int oT = grad_output.size(2);
  const int oH = grad_output.size(3);
  const int oW = grad_output.size(4);
  const int iC = grad_input.size(1);
  const int iT = grad_input.size(2);
  const int iH = grad_input.size(3);
  const int iW = grad_input.size(4);
  const int channel_multiplier = oC / iC;
  const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_;
  const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_;
  const int dilationW = kKnownDilationW > 0 ? kKnownDilationW : dilationW_;
  const int strideT = kKnownStrideT > 0 ? kKnownStrideT : strideT_;
  const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_;
  const int strideW = kKnownStrideW > 0 ? kKnownStrideW : strideW_;
  const int num_input = grad_input.size(0) * grad_input.stride(0);

  CUDA_KERNEL_LOOP(index, num_input) {
    const int in_col = index % iW;
    const int in_row = (index / iW) % iH;
    const int in_frame = (index / iW / iH) % iT;
    const int in_channel = (index / iW / iH / iT) % iC;
    const int batch = index / iW / iH / iT / iC;

    const int out_col_end = in_col + paddingW;
    const int out_row_end = in_row + paddingH;
    const int out_frame_end = in_frame + paddingT;

    const scalar_t* kernel_ptr = kernel[in_channel * channel_multiplier].data();
    accscalar_t sum = 0;

    for (int k_chn = in_channel * channel_multiplier;
         k_chn < (in_channel + 1) * channel_multiplier;
         ++k_chn) {
      const scalar_t* gout_ptr = grad_output[batch][k_chn].data();

      for (int k_frame = 0; k_frame < kT; ++k_frame) {
        const int out_frame_raw = out_frame_end - k_frame * dilationT;
        const int out_frame = out_frame_raw / strideT;
        for (int k_row = 0; k_row < kH; ++k_row) {
          const int out_row_raw = out_row_end - k_row * dilationH;
          const int out_row = out_row_raw / strideH;
          for (int k_col = 0; k_col < kW; ++k_col) {
            const accscalar_t op1 = *(kernel_ptr++);
            const int out_col_raw = out_col_end - k_col * dilationW;
            const int out_col = out_col_raw / strideW;

            const int out_offs = (out_frame * oH + out_row) * oW + out_col;

            accscalar_t op2 = (accscalar_t)0;
            if (out_col >= 0 && out_row >= 0 && out_frame >= 0 &&
                out_col < oW && out_row < oH && out_frame < oT) {
              op2 = *(gout_ptr + out_offs);
            }
            if (out_frame * strideT == out_frame_raw &&
                out_row * strideH == out_row_raw &&
                out_col * strideW == out_col_raw) {
              sum += op1 * op2;
            }
          }
        }
      }
    }

    grad_input[batch][in_channel][in_frame][in_row][in_col] = sum;
  }
}

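// Backward w.r.t. weight: one thread block per weight element, with blockIdx.x
// decoded as (k_channel, k_frame, k_row, k_col). Within the block, warps
// stride over (batch, output frame) pairs and lanes stride over output
// rows/columns, each accumulating grad_output * input products for positions
// that map back inside the input. The per-thread partial sums are then
// combined with a shared-memory tree reduction, which is why blockDim.x must
// be a power of two (asserted below).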
template <typename scalar_t, typename accscalar_t,
          int kKnownStrideH, int kKnownStrideW>
__global__ void
conv_depthwise3d_cuda_backward_weight_kernel(
    const PackedTensorAccessor32<const scalar_t, 5> grad_output,
    const PackedTensorAccessor32<const scalar_t, 5> input,
    PackedTensorAccessor32<scalar_t, 5> grad_kernel,
    int strideT, int strideH_, int strideW_,
    int paddingT, int paddingH, int paddingW,
    int dilationT, int dilationH, int dilationW) {
  const int kC = grad_kernel.size(0);
  const int kT = grad_kernel.size(2);
  const int kH = grad_kernel.size(3);
  const int kW = grad_kernel.size(4);

  const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_;
  const int strideW = kKnownStrideW > 0 ? kKnownStrideW : strideW_;

  const int k_col = blockIdx.x % kW;
  const int k_row = (blockIdx.x / kW) % kH;
  const int k_frame = (blockIdx.x / kW / kH) % kT;
  const int k_channel = blockIdx.x / kW / kH / kT;
  scalar_t *result = &grad_kernel[k_channel][0][k_frame][k_row][k_col];

  const int oT = grad_output.size(2);
  const int oH = grad_output.size(3);
  const int oW = grad_output.size(4);
  const int iT = input.size(2);
  const int iH = input.size(3);
  const int iW = input.size(4);
  const int channel_multiplier = grad_output.size(1) / input.size(1);
  const int in_channel = k_channel / channel_multiplier;

  extern __shared__ int sdata_raw[];
  scalar_t* sdata = reinterpret_cast<scalar_t*>(sdata_raw);

  if (k_channel >= kC) {
    return;
  }

  const int laneid = threadIdx.x % C10_WARP_SIZE;
  const int warpid = threadIdx.x / C10_WARP_SIZE;
  const int nwarps = blockDim.x / C10_WARP_SIZE;

  accscalar_t grad = 0;
  int batch = warpid / oT;
  int gout_frame = warpid - batch * oT;
  for (int outer_pos = warpid; outer_pos < input.size(0) * oT;
       outer_pos += nwarps, gout_frame += nwarps) {
    while (gout_frame >= oT) { gout_frame -= oT; batch ++; }

    const int in_frame = (gout_frame * strideT) + (k_frame * dilationT) - paddingT;

    if (in_frame < 0 || in_frame >= iT) {
      continue;
    }

    const scalar_t* gout_ptr = grad_output[batch][k_channel][gout_frame].data() + laneid;
    const scalar_t* input_ptr = input[batch][in_channel][in_frame].data();

    int gout_row = laneid / oW;
    int gout_col = laneid - gout_row * oW;

    for (; gout_row < oH; ) {
      const accscalar_t op1 = *(gout_ptr);
      gout_ptr += C10_WARP_SIZE;

      const int in_col = (gout_col * strideW) + (k_col * dilationW) - paddingW;
      const int in_row = (gout_row * strideH) + (k_row * dilationH) - paddingH;
      const int in_pos = in_row * iW + in_col;

      accscalar_t op2 = (accscalar_t)0;
      if (in_col >= 0 && in_col < iW && in_row >= 0 && in_row < iH) {
        op2 = *(input_ptr + in_pos);
      }

      gout_col += C10_WARP_SIZE;
      while (gout_col >= oW) {
        gout_col -= oW; gout_row ++;
      }

      grad += op1 * op2;
    }
  }

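  // Block-wide reduction of the per-thread partial sums: each thread writes
  // its accumulator to shared memory, then the active half of the block folds
  // the upper half onto the lower half until one value remains. This only
  // works for power-of-two block sizes, hence the __popc(blockDim.x) == 1
  // assert.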
  sdata[threadIdx.x] = grad;
  __syncthreads();

  CUDA_KERNEL_ASSERT(__popc(blockDim.x) == 1);
#pragma unroll
  for (int i = blockDim.x / 2; i >= 1; i >>= 1) {
    if (threadIdx.x < i) {
      sdata[threadIdx.x] += sdata[threadIdx.x + i];
    }
    __syncthreads();
  }

  if (threadIdx.x == 0) {
    *result = sdata[0];
  }
}

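// Common argument validation for the forward and backward entry points below.
// `dim` is the number of spatial dimensions (3 here), so the expected layouts
// are (C, T, H, W) or (N, C, T, H, W) for the input and
// (out_channels, 1, kT, kH, kW) for the depthwise weight, with out_channels a
// multiple of in_channels. Undefined `bias` / `grad_output` tensors skip the
// corresponding checks.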
template <int dim>
void conv_depthwise_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  TORCH_CHECK(kernel_size.size() == dim,
              "kernel size length should be ", dim, ", but got ", kernel_size.size());
  TORCH_CHECK(stride.size() == dim,
              "stride length should be ", dim, ", but got ", stride.size());
  TORCH_CHECK(padding.size() == dim,
              "padding length should be ", dim, ", but got ", padding.size());
  TORCH_CHECK(dilation.size() == dim,
              "dilation length should be ", dim, ", but got ", dilation.size());

  TORCH_CHECK(weight.defined(),
              "Weight must be defined.");
  TORCH_CHECK(input.dim() == dim + 1 || input.dim() == dim + 2,
              "Input dimension should be ",
              dim + 1, "D or ", dim + 2, "D, got ",
              input.dim(), "D");
  TORCH_CHECK(weight.dim() == dim + 2,
              "Weight dimension should be ", dim + 2, "D, got ", weight.dim(), "D");
  TORCH_CHECK(weight.size(1) == 1,
              "Depthwise weight should have in_channels=1, got ", weight.size(1));
  TORCH_CHECK(weight.size(0) % input.size(-dim - 1) == 0,
              "Depthwise out channels should be a multiple of in channels, got ",
              weight.size(0), " and ", input.size(-dim - 1));
  for (int i = 0; i < dim; ++i) {
    TORCH_CHECK(weight.size(i + 2) == kernel_size[i],
                "kernel size and weight size mismatch, got ",
                kernel_size, " and ", weight.sizes());
    TORCH_CHECK(stride[i] >= 1,
                "stride should be at least 1, got ", stride);
    TORCH_CHECK(padding[i] >= 0,
                "padding should be non-negative, got ", padding);
    TORCH_CHECK(dilation[i] >= 1,
                "dilation should be at least 1, got ", dilation);
  }

  if (bias.defined()) {
    TORCH_CHECK(bias.dim() == 1,
                "Bias should be 1D tensor, got ", bias.dim(), "D");
    TORCH_CHECK(bias.size(0) == weight.size(0),
                "Bias length should be equal to out_channels, got ",
                bias.size(0), " and ", weight.size(0));
  }

  if (grad_output.defined()) {
    auto expected_output_size = conv_output_size(input.sizes(), weight.sizes(),
                                                 padding, stride, dilation);
    TORCH_CHECK(static_cast<size_t>(grad_output.dim()) == expected_output_size.size(),
                "Expect grad_output to be ",
                expected_output_size.size(), "D, got ",
                grad_output.dim(), "D.");
    for (int i = 0; i < grad_output.dim(); ++i) {
      TORCH_CHECK(grad_output.size(i) == expected_output_size[i],
                  "Expect grad_output to be of same shape as output, got ",
                  grad_output.size(i), " and ", expected_output_size[i],
                  " at dimension ", i);
    }
  }
}

} // namespace

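// Kernel dispatch helpers. NODEF_OR_EQUAL(x, y) is true when y is a wildcard
// (negative) or matches x, so each *_DISPATCH_SPECIALIZATION macro below
// emits an `if (...) { launch specialized kernel } else` arm that fires only
// when the runtime kernel size / dilation / stride matches the compile-time
// constants, and *_DISPATCH_OTHERS terminates the chain with a fully generic
// instantiation (all template parameters -1, i.e. taken from the runtime
// arguments).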
#define NODEF_OR_EQUAL(x, y) ((y) < 0 || (x) == (y))
#define NODEF_OR_EQUAL_3(x, y1, y2, y3)    \
  (NODEF_OR_EQUAL(x[0], y1) &&             \
   NODEF_OR_EQUAL(x[1], y2) &&             \
   NODEF_OR_EQUAL(x[2], y3))

#define DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(kt, kh, kw, dilt, dilh, dilw) \
  if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) &&                       \
      NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw))) {                    \
    using accscalar_t = acc_type<scalar_t, true>;                              \
    conv_depthwise3d_cuda_kernel                                               \
      <scalar_t, accscalar_t, (kt), (kh), (kw), (dilt), (dilh), (dilw)>        \
      <<<grid, block, (smem), at::cuda::getCurrentCUDAStream()>>>(             \
        input_.packed_accessor32<const scalar_t, 5>(),                         \
        output_.packed_accessor32<scalar_t, 5>(),                              \
        weight_.packed_accessor32<const scalar_t, 5>(),                        \
        bias_ptr,                                                              \
        stride[0], stride[1], stride[2],                                       \
        padding[0], padding[1], padding[2],                                    \
        dilation[0], dilation[1], dilation[2]);                                \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                            \
  } else

#define DWCONV3D_FORWARD_DISPATCH_OTHERS                                 \
  {                                                                      \
    using accscalar_t = acc_type<scalar_t, true>;                        \
    conv_depthwise3d_cuda_kernel                                         \
      <scalar_t, accscalar_t, -1, -1, -1, -1, -1, -1>                    \
      <<<grid, block, (smem), at::cuda::getCurrentCUDAStream()>>>(       \
        input_.packed_accessor32<const scalar_t, 5>(),                   \
        output_.packed_accessor32<scalar_t, 5>(),                        \
        weight_.packed_accessor32<const scalar_t, 5>(),                  \
        bias_ptr,                                                        \
        stride[0], stride[1], stride[2],                                 \
        padding[0], padding[1], padding[2],                              \
        dilation[0], dilation[1], dilation[2]);                          \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                      \
  }

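// Forward entry point. Performs the shape checks, makes input/weight/bias
// contiguous, adds a batch dimension for 4D (unbatched) inputs, verifies that
// every tensor fits in 32-bit indexing (the kernels use
// PackedTensorAccessor32), and then launches the forward kernel, preferring
// the specialized 3x3x3 / dilation-1 instantiations when they apply.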
Tensor conv_depthwise3d_cuda(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  TORCH_CHECK(input.device() == weight.device(), "expects input and weight tensors to be on the same device.");
  if (bias.defined()) {
    TORCH_CHECK(input.device() == bias.device(), "expects input and bias tensors to be on the same device.");
  }

  conv_depthwise_shape_check<3>(input, weight, bias, Tensor() /* undefined */,
                                kernel_size, stride, padding, dilation);

  Tensor input_ = input.contiguous();

  if (input.dim() == 4 /* no batch */) {
    // Unsqueeze the contiguous copy so the kernels keep seeing packed strides.
    input_ = input_.unsqueeze(0);
  }

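  // conv_output_size applies the usual convolution output-size formula per
  // spatial dimension:
  //   out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
  // (integer division); the batch size is carried over and the channel count
  // comes from weight.size(0).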
  auto output_size = conv_output_size(input_.sizes(), weight.sizes(),
                                      padding, stride, dilation);
  for (size_t i = 0; i < output_size.size(); ++i) {
    TORCH_CHECK(output_size[i] > 0,
                "Output size should be positive, got ", output_size[i], " at dim ", i);
  }
  Tensor output = at::empty(output_size, input.options());
  Tensor output_ = output;
  Tensor weight_ = weight.contiguous();
  Tensor bias_ = bias.defined() ? bias.contiguous() : bias;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      input.scalar_type(),
      "conv_depthwise3d",
      [&]{
        int64_t num_outputs = output_.numel();
        int64_t block = 256;
        int64_t grid = std::min((num_outputs - 1) / block + 1, (int64_t)65536);
        int64_t smem = 0;

        const scalar_t* bias_ptr =
            bias_.defined() ? bias_.const_data_ptr<scalar_t>() : NULL;

        // Range check to avoid overflow in CUDA kernels.
        TORCH_CHECK(input_.numel() <= std::numeric_limits<int32_t>::max(),
                    "Input tensor is too large.");
        TORCH_CHECK(output_.numel() <= std::numeric_limits<int32_t>::max(),
                    "Output tensor is too large.");
        TORCH_CHECK(weight_.numel() <= std::numeric_limits<int32_t>::max(),
                    "Weight tensor is too large.");
        for (int i = 0; i < 3; ++i) {
          TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits<int32_t>::max(),
                      "Padded input tensor is too large.");
        }

        DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(3, 3, 3, 1, 1, 1)
        DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(-1, -1, -1, 1, 1, 1)
        DWCONV3D_FORWARD_DISPATCH_OTHERS
      }
  );

  return output;
}

#undef DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION
#undef DWCONV3D_FORWARD_DISPATCH_OTHERS

#define DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(                       \
    kt, kh, kw, dilt, dilh, dilw, dt, dh, dw)                                  \
  if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) &&                       \
      NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw)) &&                    \
      NODEF_OR_EQUAL_3(stride, (dt), (dh), (dw))) {                            \
    using accscalar_t = acc_type<scalar_t, true>;                              \
    conv_depthwise3d_cuda_backward_input_kernel                                \
      <scalar_t, accscalar_t, (kt), (kh), (kw), (dilt), (dilh), (dilw), (dt), (dh), (dw)> \
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(                  \
        grad_output_.packed_accessor32<const scalar_t, 5>(),                   \
        grad_input_.packed_accessor32<scalar_t, 5>(),                          \
        weight_.packed_accessor32<const scalar_t, 5>(),                        \
        stride[0], stride[1], stride[2],                                       \
        padding[0], padding[1], padding[2],                                    \
        dilation[0], dilation[1], dilation[2]);                                \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                            \
  } else

#define DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS                          \
  {                                                                      \
    using accscalar_t = acc_type<scalar_t, true>;                        \
    conv_depthwise3d_cuda_backward_input_kernel                          \
      <scalar_t, accscalar_t, -1, -1, -1, -1, -1, -1, -1, -1, -1>        \
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(            \
        grad_output_.packed_accessor32<const scalar_t, 5>(),             \
        grad_input_.packed_accessor32<scalar_t, 5>(),                    \
        weight_.packed_accessor32<const scalar_t, 5>(),                  \
        stride[0], stride[1], stride[2],                                 \
        padding[0], padding[1], padding[2],                              \
        dilation[0], dilation[1], dilation[2]);                          \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                      \
  }

#define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(dh, dw)         \
  if (NODEF_OR_EQUAL_3(stride, -1, (dh), (dw))) {                        \
    using accscalar_t = acc_type<scalar_t, true>;                        \
    conv_depthwise3d_cuda_backward_weight_kernel                         \
      <scalar_t, accscalar_t, (dh), (dw)>                                \
      <<<grid, block, smem, at::cuda::getCurrentCUDAStream()>>>(         \
        grad_output_.packed_accessor32<const scalar_t, 5>(),             \
        input_.packed_accessor32<const scalar_t, 5>(),                   \
        grad_weight.packed_accessor32<scalar_t, 5>(),                    \
        stride[0], stride[1], stride[2],                                 \
        padding[0], padding[1], padding[2],                              \
        dilation[0], dilation[1], dilation[2]);                          \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                      \
  } else

#define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS                         \
  {                                                                      \
    using accscalar_t = acc_type<scalar_t, true>;                        \
    conv_depthwise3d_cuda_backward_weight_kernel                         \
      <scalar_t, accscalar_t, -1, -1>                                    \
      <<<grid, block, smem, at::cuda::getCurrentCUDAStream()>>>(         \
        grad_output_.packed_accessor32<const scalar_t, 5>(),             \
        input_.packed_accessor32<const scalar_t, 5>(),                   \
        grad_weight.packed_accessor32<scalar_t, 5>(),                    \
        stride[0], stride[1], stride[2],                                 \
        padding[0], padding[1], padding[2],                              \
        dilation[0], dilation[1], dilation[2]);                          \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                      \
  }

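// Shared backward implementation. Depending on output_mask it fills
// grad_input (one thread per input element), grad_weight (one block per
// weight element, accumulated in the dtype's accumulation type), and
// grad_bias, which is simply grad_output summed over the batch and spatial
// dimensions. The same 32-bit indexing limits as in the forward path are
// enforced before each launch.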
std::tuple<Tensor&, Tensor&, Tensor&> _depthwise_3d_backward_cuda_out(
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const std::array<bool, 3> output_mask)
{

  TORCH_CHECK(grad_output.device() == input.device() &&
              input.device() == weight.device(),
              "expects input, weight and grad_output to be on the same device.");
  conv_depthwise_shape_check<3>(
      input, weight, Tensor() /* undefined */, grad_output,
      kernel_size, stride, padding, dilation);

  const Tensor grad_output_ = grad_output.contiguous();

  Tensor grad_input_ =
      (output_mask[0] ? grad_input
                      : Tensor());

  if (output_mask[0]) {
    const Tensor weight_ = weight.contiguous();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        grad_output.scalar_type(),
        "conv_depthwise3d",
        [&] {
          int64_t num_inputs = grad_input_.numel();
          int64_t block = 256;
          int64_t grid = std::min((num_inputs - 1) / block + 1, (int64_t)65536);

          // Range check to avoid overflow in CUDA kernels.
          TORCH_CHECK(grad_input_.numel() <= std::numeric_limits<int32_t>::max(),
                      "Input tensor is too large.");
          TORCH_CHECK(grad_output_.numel() <= std::numeric_limits<int32_t>::max(),
                      "Output tensor is too large.");
          TORCH_CHECK(weight_.numel() <= std::numeric_limits<int32_t>::max(),
                      "Weight tensor is too large.");
          for (int i = 0; i < 3; ++i) {
            TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits<int32_t>::max(),
                        "Padded input tensor is too large.");
          }

          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, 1, 1, 1, 1, 1, 1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, 1, 1, 1, -1, -1, -1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, -1, -1, -1, 1, 1, 1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, -1, -1, -1, -1, -1, -1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS
        }
    );
  }

  if (output_mask[1]) {
    const Tensor input_ = input.contiguous();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        grad_output.scalar_type(),
        "conv_depthwise3d",
        [&] {
          int64_t grid = grad_weight.numel();
          int64_t block = 256;
          int64_t smem = sizeof(scalar_t) * block;

          const int64_t int_max = std::numeric_limits<int32_t>::max();
          TORCH_CHECK(input_.numel() <= int_max,
                      "Input tensor is too large.");
          TORCH_CHECK(grad_output_.numel() <= int_max,
                      "Output tensor is too large.");
          TORCH_CHECK(weight.numel() <= int_max,
                      "Weight tensor is too large.");
          for (int i = 0; i < 3; ++i) {
            TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= int_max,
                        "Padded input tensor is too large.");
          }
          int64_t warp_size = at::cuda::warp_size();
          TORCH_CHECK(grad_output_.size(0) * grad_output_.size(2) < int_max - block / warp_size &&
                      grad_output_.size(3) <= int_max - warp_size &&
                      grad_output_.size(4) <= int_max - warp_size,
                      "Output size is too large.");

          DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(1, 1)
          DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(2, 2)
          DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS
        }
    );
  }

  if (output_mask[2]) {
    grad_bias = grad_output.sum({0, 2, 3, 4});
  }

  return std::tie(grad_input, grad_weight, grad_bias);

}

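// Public wrappers. The out-variant resizes grad_weight to the weight shape
// and zeroes it up front, then requests all three gradients unconditionally
// ({true, true, true}); the functional variant below allocates grad_input and
// grad_weight only for the entries set in output_mask and leaves grad_bias to
// be produced by the sum-reduction in _depthwise_3d_backward_cuda_out.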
std::tuple<Tensor&, Tensor&, Tensor&> conv_depthwise3d_backward_cuda_out(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  return _depthwise_3d_backward_cuda_out(
      grad_input,
      grad_weight,
      grad_bias,
      grad_output,
      input,
      weight,
      kernel_size,
      stride,
      padding,
      dilation,
      {true, true, true});
}

std::tuple<Tensor, Tensor, Tensor> conv_depthwise3d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const std::array<bool, 3> output_mask) {

  auto options = grad_output.options();
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options) : Tensor());
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options) : Tensor());
  Tensor grad_bias; /* undefined temporarily */

  return _depthwise_3d_backward_cuda_out(
      grad_input,
      grad_weight,
      grad_bias,
      grad_output,
      input,
      weight,
      kernel_size,
      stride,
      padding,
      dilation,
      output_mask
  );

}

REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda);

#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION
#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS

#undef NODEF_OR_EQUAL_3
#undef NODEF_OR_EQUAL

} // namespace at::native