#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/AccumulateType.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/ConvUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/conv_depthwise3d_native.h>
#endif

#include <algorithm>
#include <tuple>
#include <limits>

namespace at::native {
namespace {

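// Forward depthwise 3d convolution kernel. Threads walk all output elements
// (batch, out_channel, out_frame, out_row, out_col) in a grid-stride loop
// (CUDA_KERNEL_LOOP), computing one output value per iteration. Kernel extents
// and dilations can be fixed at compile time via the kKnown* template
// parameters (-1 means "use the runtime argument"), which lets the compiler
// unroll the inner loops for the common 3x3x3, dilation-1 specialization.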
template <typename scalar_t, typename accscalar_t,
    int kKnownKernelT, int kKnownKernelH, int kKnownKernelW,
    int kKnownDilationT, int kKnownDilationH, int kKnownDilationW>
__global__ void conv_depthwise3d_cuda_kernel(
    const PackedTensorAccessor32<const scalar_t, 5> input,
    PackedTensorAccessor32<scalar_t, 5> output,
    const PackedTensorAccessor32<const scalar_t, 5> kernel,
    const scalar_t* bias,
    int strideT, int strideH, int strideW,
    int paddingT, int paddingH, int paddingW,
    int dilationT_, int dilationH_, int dilationW_)
{
  const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2);
  const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3);
  const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4);
  const int oC = output.size(1);
  const int oT = output.size(2);
  const int oH = output.size(3);
  const int oW = output.size(4);
  const int iC = input.size(1);
  const int iT = input.size(2);
  const int iH = input.size(3);
  const int iW = input.size(4);
  const int channel_multiplier = oC / iC;
  const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_;
  const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_;
  const int dilationW = kKnownDilationW > 0 ? kKnownDilationW : dilationW_;
  const int num_output = output.size(0) * output.stride(0);

  CUDA_KERNEL_LOOP(index, num_output) {
    const int out_col = index % oW;
    const int out_row = (index / oW) % oH;
    const int out_frame = (index / oW / oH) % oT;
    const int out_channel = (index / oW / oH / oT) % oC;
    const int batch = index / oW / oH / oT / oC;

    const int in_channel = out_channel / channel_multiplier;

    const int in_col_start = out_col * strideW - paddingW;
    const int in_row_start = out_row * strideH - paddingH;
    const int in_frame_start = out_frame * strideT - paddingT;

    accscalar_t sum = 0;
    const scalar_t *kernel_ptr = kernel[out_channel].data();
    const scalar_t *input_ptr =
        &input[batch][in_channel][in_frame_start][in_row_start][in_col_start];
    for (int k_frame = 0; k_frame < kT; ++k_frame) {
      const int in_frame = in_frame_start + k_frame * dilationT;
      for (int k_row = 0; k_row < kH; ++k_row) {
        const int in_row = in_row_start + k_row * dilationH;
        for (int k_col = 0; k_col < kW; ++k_col) {
          const accscalar_t op1 = *(kernel_ptr++);
          const int in_col = in_col_start + k_col * dilationW;
          if (in_frame >= 0 && in_row >= 0 && in_col >= 0 &&
              in_frame < iT && in_row < iH && in_col < iW) {
            sum += op1 * *(input_ptr);
          }
          input_ptr += dilationW;
        }
        input_ptr += iW * dilationH - kW * dilationW;
      }
      input_ptr += iW * (iH * dilationT - kH * dilationH);
    }
    if (bias != NULL) {
      sum += bias[out_channel];
    }

    output[batch][out_channel][out_frame][out_row][out_col] = sum;
  }
}

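// Backward-by-input kernel. Each iteration of the grid-stride loop accumulates
// the gradient for one input element: it walks over the channel_multiplier
// output channels that read from that input channel and over all kernel taps,
// and adds kernel * grad_output only when the tap maps back onto an exact
// stride multiple (out_* * stride == raw offset) that lies inside the output
// tensor. Strides can also be specialized at compile time via kKnownStride*.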
template <typename scalar_t, typename accscalar_t,
    int kKnownKernelT, int kKnownKernelH, int kKnownKernelW,
    int kKnownDilationT, int kKnownDilationH, int kKnownDilationW,
    int kKnownStrideT, int kKnownStrideH, int kKnownStrideW>
__global__ void
conv_depthwise3d_cuda_backward_input_kernel(
    const PackedTensorAccessor32<const scalar_t, 5> grad_output,
    PackedTensorAccessor32<scalar_t, 5> grad_input,
    const PackedTensorAccessor32<const scalar_t, 5> kernel,
    int strideT_, int strideH_, int strideW_,
    int paddingT, int paddingH, int paddingW,
    int dilationT_, int dilationH_, int dilationW_) {
  const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2);
  const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3);
  const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4);
  const int oC = grad_output.size(1);
  const int oT = grad_output.size(2);
  const int oH = grad_output.size(3);
  const int oW = grad_output.size(4);
  const int iC = grad_input.size(1);
  const int iT = grad_input.size(2);
  const int iH = grad_input.size(3);
  const int iW = grad_input.size(4);
  const int channel_multiplier = oC / iC;
  const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_;
  const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_;
  const int dilationW = kKnownDilationW > 0 ? kKnownDilationW : dilationW_;
  const int strideT = kKnownStrideT > 0 ? kKnownStrideT : strideT_;
  const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_;
  const int strideW = kKnownStrideW > 0 ? kKnownStrideW : strideW_;
  const int num_input = grad_input.size(0) * grad_input.stride(0);

  CUDA_KERNEL_LOOP(index, num_input) {
    const int in_col = index % iW;
    const int in_row = (index / iW) % iH;
    const int in_frame = (index / iW / iH) % iT;
    const int in_channel = (index / iW / iH / iT) % iC;
    const int batch = index / iW / iH / iT / iC;

    const int out_col_end = in_col + paddingW;
    const int out_row_end = in_row + paddingH;
    const int out_frame_end = in_frame + paddingT;

    const scalar_t* kernel_ptr = kernel[in_channel * channel_multiplier].data();
    accscalar_t sum = 0;

    for (int k_chn = in_channel * channel_multiplier;
         k_chn < (in_channel + 1) * channel_multiplier;
         ++k_chn) {
      const scalar_t* gout_ptr = grad_output[batch][k_chn].data();

      for (int k_frame = 0; k_frame < kT; ++k_frame) {
        const int out_frame_raw = out_frame_end - k_frame * dilationT;
        const int out_frame = out_frame_raw / strideT;
        for (int k_row = 0; k_row < kH; ++k_row) {
          const int out_row_raw = out_row_end - k_row * dilationH;
          const int out_row = out_row_raw / strideH;
          for (int k_col = 0; k_col < kW; ++k_col) {
            const accscalar_t op1 = *(kernel_ptr++);
            const int out_col_raw = out_col_end - k_col * dilationW;
            const int out_col = out_col_raw / strideW;

            const int out_offs = (out_frame * oH + out_row) * oW + out_col;

            accscalar_t op2 = (accscalar_t)0;
            if (out_col >= 0 && out_row >= 0 && out_frame >= 0 &&
                out_col < oW && out_row < oH && out_frame < oT) {
              op2 = *(gout_ptr + out_offs);
            }
            if (out_frame * strideT == out_frame_raw &&
                out_row * strideH == out_row_raw &&
                out_col * strideW == out_col_raw) {
              sum += op1 * op2;
            }
          }
        }
      }
    }

    grad_input[batch][in_channel][in_frame][in_row][in_col] = sum;
  }
}

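// Backward-by-weight kernel. One thread block handles one weight element
// (channel, frame, row, col), decoded from blockIdx.x. Within the block, each
// warp iterates over (batch, output frame) pairs and its lanes stride across
// the output H x W plane, multiplying grad_output by the corresponding input
// value. Per-thread partial sums are then combined with a shared-memory tree
// reduction, which requires blockDim.x to be a power of two (asserted below).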
template <typename scalar_t, typename accscalar_t,
    int kKnownStrideH, int kKnownStrideW>
__global__ void
conv_depthwise3d_cuda_backward_weight_kernel(
    const PackedTensorAccessor32<const scalar_t, 5> grad_output,
    const PackedTensorAccessor32<const scalar_t, 5> input,
    PackedTensorAccessor32<scalar_t, 5> grad_kernel,
    int strideT, int strideH_, int strideW_,
    int paddingT, int paddingH, int paddingW,
    int dilationT, int dilationH, int dilationW) {
  const int kC = grad_kernel.size(0);
  const int kT = grad_kernel.size(2);
  const int kH = grad_kernel.size(3);
  const int kW = grad_kernel.size(4);

  const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_;
  const int strideW = kKnownStrideW > 0 ? kKnownStrideW : strideW_;

  const int k_col = blockIdx.x % kW;
  const int k_row = (blockIdx.x / kW) % kH;
  const int k_frame = (blockIdx.x / kW / kH) % kT;
  const int k_channel = blockIdx.x / kW / kH / kT;
  scalar_t *result = &grad_kernel[k_channel][0][k_frame][k_row][k_col];

  const int oT = grad_output.size(2);
  const int oH = grad_output.size(3);
  const int oW = grad_output.size(4);
  const int iT = input.size(2);
  const int iH = input.size(3);
  const int iW = input.size(4);
  const int channel_multiplier = grad_output.size(1) / input.size(1);
  const int in_channel = k_channel / channel_multiplier;

  extern __shared__ int sdata_raw[];
  scalar_t* sdata = reinterpret_cast<scalar_t*>(sdata_raw);

  if (k_channel >= kC) {
    return;
  }

  const int laneid = threadIdx.x % C10_WARP_SIZE;
  const int warpid = threadIdx.x / C10_WARP_SIZE;
  const int nwarps = blockDim.x / C10_WARP_SIZE;

  accscalar_t grad = 0;
  int batch = warpid / oT;
  int gout_frame = warpid - batch * oT;
  for (int outer_pos = warpid; outer_pos < input.size(0) * oT;
       outer_pos += nwarps, gout_frame += nwarps) {
    while (gout_frame >= oT) { gout_frame -= oT; batch ++; }

    const int in_frame = (gout_frame * strideT) + (k_frame * dilationT) - paddingT;

    if (in_frame < 0 || in_frame >= iT) {
      continue;
    }

    const scalar_t* gout_ptr = grad_output[batch][k_channel][gout_frame].data() + laneid;
    const scalar_t* input_ptr = input[batch][in_channel][in_frame].data();

    int gout_row = laneid / oW;
    int gout_col = laneid - gout_row * oW;

    for (; gout_row < oH; ) {
      const accscalar_t op1 = *(gout_ptr);
      gout_ptr += C10_WARP_SIZE;

      const int in_col = (gout_col * strideW) + (k_col * dilationW) - paddingW;
      const int in_row = (gout_row * strideH) + (k_row * dilationH) - paddingH;
      const int in_pos = in_row * iW + in_col;

      accscalar_t op2 = (accscalar_t)0;
      if (in_col >= 0 && in_col < iW && in_row >= 0 && in_row < iH) {
        op2 = *(input_ptr + in_pos);
      }

      gout_col += C10_WARP_SIZE;
      while (gout_col >= oW) {
        gout_col -= oW; gout_row ++;
      }

      grad += op1 * op2;
    }
  }

  sdata[threadIdx.x] = grad;
  __syncthreads();

  CUDA_KERNEL_ASSERT(__popc(blockDim.x) == 1);
#pragma unroll
  for (int i = blockDim.x / 2; i >= 1; i >>= 1) {
    if (threadIdx.x < i) {
      sdata[threadIdx.x] += sdata[threadIdx.x + i];
    }
    __syncthreads();
  }

  if (threadIdx.x == 0) {
    *result = sdata[0];
  }
}

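// Validates the argument shapes shared by the forward and backward paths:
// lengths of kernel_size/stride/padding/dilation, dimensionality of input and
// weight, the depthwise constraints (weight.size(1) == 1 and out_channels a
// multiple of in_channels), and, when grad_output is defined, that it matches
// the expected output shape.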
template <int dim>
void conv_depthwise_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  TORCH_CHECK(kernel_size.size() == dim,
              "kernel size length should be ", dim, ", but got ", kernel_size.size());
  TORCH_CHECK(stride.size() == dim,
              "stride length should be ", dim, ", but got ", stride.size());
  TORCH_CHECK(padding.size() == dim,
              "padding length should be ", dim, ", but got ", padding.size());
  TORCH_CHECK(dilation.size() == dim,
              "dilation length should be ", dim, ", but got ", dilation.size());

  TORCH_CHECK(weight.defined(),
              "Weight must be defined.");
  TORCH_CHECK(input.dim() == dim + 1 || input.dim() == dim + 2,
              "Input dimension should be ",
              dim + 1, "D or ", dim + 2, "D, got ",
              input.dim(), "D");
  TORCH_CHECK(weight.dim() == dim + 2,
              "Weight dimension should be ", dim + 2, "D, got ", weight.dim(), "D");
  TORCH_CHECK(weight.size(1) == 1,
              "Depthwise weight should have in_channels=1, got ", weight.size(1));
  TORCH_CHECK(weight.size(0) % input.size(-dim - 1) == 0,
              "Depthwise out channels should be a multiple of in channels, got ",
              weight.size(0), " and ", input.size(-dim - 1));
  for (int i = 0; i < dim; ++i) {
    TORCH_CHECK(weight.size(i + 2) == kernel_size[i],
                "kernel size and weight size mismatch, got ",
                kernel_size, " and ", weight.sizes());
    TORCH_CHECK(stride[i] >= 1,
                "stride should be at least 1, got ", stride);
    TORCH_CHECK(padding[i] >= 0,
                "padding should be non-negative, got ", padding);
    TORCH_CHECK(dilation[i] >= 1,
                "dilation should be at least 1, got ", dilation);
  }

  if (bias.defined()) {
    TORCH_CHECK(bias.dim() == 1,
                "Bias should be 1D tensor, got ", bias.dim(), "D");
    TORCH_CHECK(bias.size(0) == weight.size(0),
                "Bias length should be equal to out_channels, got ",
                bias.size(0), " and ", weight.size(0));
  }

  if (grad_output.defined()) {
    auto expected_output_size = conv_output_size(input.sizes(), weight.sizes(),
                                                 padding, stride, dilation);
    TORCH_CHECK(static_cast<size_t>(grad_output.dim()) == expected_output_size.size(),
                "Expect grad_output to be ",
                expected_output_size.size(), "D, got ",
                grad_output.dim(), "D.");
    for (int i = 0; i < grad_output.dim(); ++i) {
      TORCH_CHECK(grad_output.size(i) == expected_output_size[i],
                  "Expect grad_output to be of same shape as output, got ",
                  grad_output.size(i), " and ", expected_output_size[i],
                  " at dimension ", i);
    }
  }
}

}

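// Dispatch helpers. NODEF_OR_EQUAL(x, y) treats a negative y as "unspecified",
// so NODEF_OR_EQUAL_3 matches a triple against a possibly partial pattern.
// Each *_DISPATCH_SPECIALIZATION macro ends with `} else` so the macros form
// an if/else-if chain that launches a kernel instantiation with the matched
// compile-time constants; *_DISPATCH_OTHERS is the fully dynamic fallback
// that closes the chain.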
#define NODEF_OR_EQUAL(x, y) ((y) < 0 || (x) == (y))
#define NODEF_OR_EQUAL_3(x, y1, y2, y3) \
  (NODEF_OR_EQUAL(x[0], y1) && \
   NODEF_OR_EQUAL(x[1], y2) && \
   NODEF_OR_EQUAL(x[2], y3))

#define DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(kt, kh, kw, dilt, dilh, dilw) \
  if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) &&                    \
      NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw))) {                 \
    using accscalar_t = acc_type<scalar_t, true>;                           \
    conv_depthwise3d_cuda_kernel                                            \
    <scalar_t, accscalar_t, (kt), (kh), (kw), (dilt), (dilh), (dilw)>       \
      <<<grid, block, (smem), at::cuda::getCurrentCUDAStream()>>>(          \
        input_.packed_accessor32<const scalar_t, 5>(),                      \
        output_.packed_accessor32<scalar_t, 5>(),                           \
        weight_.packed_accessor32<const scalar_t, 5>(),                     \
        bias_ptr,                                                           \
        stride[0], stride[1], stride[2],                                    \
        padding[0], padding[1], padding[2],                                 \
        dilation[0], dilation[1], dilation[2]);                             \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  } else

#define DWCONV3D_FORWARD_DISPATCH_OTHERS \
  {                                      \
    using accscalar_t = acc_type<scalar_t, true>;                           \
    conv_depthwise3d_cuda_kernel                                            \
    <scalar_t, accscalar_t, -1, -1, -1, -1, -1, -1>                         \
      <<<grid, block, (smem), at::cuda::getCurrentCUDAStream()>>>(          \
        input_.packed_accessor32<const scalar_t, 5>(),                      \
        output_.packed_accessor32<scalar_t, 5>(),                           \
        weight_.packed_accessor32<const scalar_t, 5>(),                     \
        bias_ptr,                                                           \
        stride[0], stride[1], stride[2],                                    \
        padding[0], padding[1], padding[2],                                 \
        dilation[0], dilation[1], dilation[2]);                             \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  }

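// Forward entry point. Accepts a 4D (C, T, H, W) or 5D (N, C, T, H, W) input
// and a (C * multiplier, 1, kT, kH, kW) weight, allocates the output, and
// dispatches on dtype, kernel size, and dilation. Illustrative call, assuming
// suitably shaped CUDA tensors `input` and `weight` (the argument values here
// are only an example, not taken from any particular caller):
//
//   Tensor out = conv_depthwise3d_cuda(
//       input, weight, /*kernel_size=*/{3, 3, 3}, /*bias_opt=*/std::nullopt,
//       /*stride=*/{1, 1, 1}, /*padding=*/{1, 1, 1}, /*dilation=*/{1, 1, 1});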
Tensor conv_depthwise3d_cuda(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  TORCH_CHECK(input.device() == weight.device(), "expects input and weight tensors to be on the same device.");
  if (bias.defined()) {
    TORCH_CHECK(input.device() == bias.device(), "expects input and bias tensors to be on the same device.");
  }

  conv_depthwise_shape_check<3>(input, weight, bias, Tensor() /* undefined */,
                                kernel_size, stride, padding, dilation);

  Tensor input_ = input.contiguous();

  if (input.dim() == 4 /* no batch */) {
    input_ = input.unsqueeze(0);
  }

  auto output_size = conv_output_size(input_.sizes(), weight.sizes(),
                                      padding, stride, dilation);
  for (size_t i = 0; i < output_size.size(); ++i) {
    TORCH_CHECK(output_size[i] > 0,
                "Output size should be positive, got ", output_size[i], " at dim ", i);
  }
  Tensor output = at::empty(output_size, input.options());
  Tensor output_ = output;
  Tensor weight_ = weight.contiguous();
  Tensor bias_ = bias.defined() ? bias.contiguous() : bias;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      input.scalar_type(),
      "conv_depthwise3d",
      [&]{
        int64_t num_outputs = output_.numel();
        int64_t block = 256;
        int64_t grid = std::min((num_outputs - 1) / block + 1, (int64_t)65536);
        int64_t smem = 0;

        const scalar_t* bias_ptr =
            bias_.defined() ? bias_.const_data_ptr<scalar_t>() : NULL;

        // Range check to avoid overflow in CUDA kernels.
        TORCH_CHECK(input_.numel() <= std::numeric_limits<int32_t>::max(),
                    "Input tensor is too large.");
        TORCH_CHECK(output_.numel() <= std::numeric_limits<int32_t>::max(),
                    "Output tensor is too large.");
        TORCH_CHECK(weight_.numel() <= std::numeric_limits<int32_t>::max(),
                    "Weight tensor is too large.");
        for (int i = 0; i < 3; ++i) {
          TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits<int32_t>::max(),
                      "Padded input tensor is too large.");
        }

        DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(3, 3, 3, 1, 1, 1)
        DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(-1, -1, -1, 1, 1, 1)
        DWCONV3D_FORWARD_DISPATCH_OTHERS
      }
  );

  return output;
}

#undef DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION
#undef DWCONV3D_FORWARD_DISPATCH_OTHERS

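// Backward dispatch helpers, mirroring the forward ones: the input-gradient
// macros additionally specialize on stride, while the weight-gradient macros
// specialize only on the H/W strides used in the innermost loop.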
#define DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(                    \
    kt, kh, kw, dilt, dilh, dilw, dt, dh, dw)                               \
  if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) &&                    \
      NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw)) &&                 \
      NODEF_OR_EQUAL_3(stride, (dt), (dh), (dw))) {                         \
    using accscalar_t = acc_type<scalar_t, true>;                           \
    conv_depthwise3d_cuda_backward_input_kernel                             \
    <scalar_t, accscalar_t, (kt), (kh), (kw), (dilt), (dilh), (dilw), (dt), (dh), (dw)>  \
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(               \
        grad_output_.packed_accessor32<const scalar_t, 5>(),                \
        grad_input_.packed_accessor32<scalar_t, 5>(),                       \
        weight_.packed_accessor32<const scalar_t, 5>(),                     \
        stride[0], stride[1], stride[2],                                    \
        padding[0], padding[1], padding[2],                                 \
        dilation[0], dilation[1], dilation[2]);                             \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  } else

#define DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS                             \
  {                                                                         \
    using accscalar_t = acc_type<scalar_t, true>;                           \
    conv_depthwise3d_cuda_backward_input_kernel                             \
    <scalar_t, accscalar_t, -1, -1, -1, -1, -1, -1, -1, -1, -1>             \
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(               \
        grad_output_.packed_accessor32<const scalar_t, 5>(),                \
        grad_input_.packed_accessor32<scalar_t, 5>(),                       \
        weight_.packed_accessor32<const scalar_t, 5>(),                     \
        stride[0], stride[1], stride[2],                                    \
        padding[0], padding[1], padding[2],                                 \
        dilation[0], dilation[1], dilation[2]);                             \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  }

#define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(dh, dw)            \
  if (NODEF_OR_EQUAL_3(stride, -1, (dh), (dw))) {                           \
    using accscalar_t = acc_type<scalar_t, true>;                           \
    conv_depthwise3d_cuda_backward_weight_kernel                            \
    <scalar_t, accscalar_t, (dh), (dw)>                                     \
      <<<grid, block, smem, at::cuda::getCurrentCUDAStream()>>>(            \
        grad_output_.packed_accessor32<const scalar_t, 5>(),                \
        input_.packed_accessor32<const scalar_t, 5>(),                      \
        grad_weight.packed_accessor32<scalar_t, 5>(),                       \
        stride[0], stride[1], stride[2],                                    \
        padding[0], padding[1], padding[2],                                 \
        dilation[0], dilation[1], dilation[2]);                             \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  } else

#define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS                            \
  {                                                                         \
    using accscalar_t = acc_type<scalar_t, true>;                           \
    conv_depthwise3d_cuda_backward_weight_kernel                            \
    <scalar_t, accscalar_t, -1, -1>                                         \
      <<<grid, block, smem, at::cuda::getCurrentCUDAStream()>>>(            \
        grad_output_.packed_accessor32<const scalar_t, 5>(),                \
        input_.packed_accessor32<const scalar_t, 5>(),                      \
        grad_weight.packed_accessor32<scalar_t, 5>(),                       \
        stride[0], stride[1], stride[2],                                    \
        padding[0], padding[1], padding[2],                                 \
        dilation[0], dilation[1], dilation[2]);                             \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                         \
  }

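// Shared backward implementation. Depending on output_mask it computes
// grad_input (input-gradient kernel), grad_weight (weight-gradient kernel,
// one block per weight element) and grad_bias (a reduction of grad_output
// over all dimensions except the channel dimension).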
std::tuple<Tensor&, Tensor&, Tensor&> _depthwise_3d_backward_cuda_out(
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const std::array<bool, 3> output_mask)
{

  TORCH_CHECK(grad_output.device() == input.device() &&
              input.device() == weight.device(),
              "expects input, weight and grad_output to be on the same device.");
  conv_depthwise_shape_check<3>(
      input, weight, Tensor() /* undefined */, grad_output,
      kernel_size, stride, padding, dilation);

  const Tensor grad_output_ = grad_output.contiguous();

  Tensor grad_input_ =
      (output_mask[0] ?  grad_input
                      : Tensor());

  if (output_mask[0]) {
    const Tensor weight_ = weight.contiguous();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        grad_output.scalar_type(),
        "conv_depthwise3d",
        [&] {
          int64_t num_inputs = grad_input_.numel();
          int64_t block = 256;
          int64_t grid = std::min((num_inputs - 1) / block + 1, (int64_t)65536);

          // Range check to avoid overflow in CUDA kernels.
          TORCH_CHECK(grad_input_.numel() <= std::numeric_limits<int32_t>::max(),
                      "Input tensor is too large.");
          TORCH_CHECK(grad_output_.numel() <= std::numeric_limits<int32_t>::max(),
                      "Output tensor is too large.");
          TORCH_CHECK(weight_.numel() <= std::numeric_limits<int32_t>::max(),
                      "Weight tensor is too large.");
          for (int i = 0; i < 3; ++i) {
            TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits<int32_t>::max(),
                        "Padded input tensor is too large.");
          }

          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, 1, 1, 1, 1, 1, 1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, 1, 1, 1, -1, -1, -1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, -1, -1, -1, 1, 1, 1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION(
              3, 3, 3, -1, -1, -1, -1, -1, -1)
          DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS
        }
    );
  }

  if (output_mask[1]) {
    const Tensor input_ = input.contiguous();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf,
        kBFloat16,
        grad_output.scalar_type(),
        "conv_depthwise3d",
        [&] {
          int64_t grid = grad_weight.numel();
          int64_t block = 256;
          int64_t smem = sizeof(scalar_t) * block;

          const int64_t int_max = std::numeric_limits<int32_t>::max();
          TORCH_CHECK(grad_input_.numel() <= int_max,
                      "Input tensor is too large.");
          TORCH_CHECK(grad_output_.numel() <= int_max,
                      "Output tensor is too large.");
          TORCH_CHECK(weight.numel() <= int_max,
                      "Weight tensor is too large.");
          for (int i = 0; i < 3; ++i) {
            TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= int_max,
                        "Padded input tensor is too large.");
          }
          int64_t warp_size = at::cuda::warp_size();
          TORCH_CHECK(grad_output_.size(0) * grad_output_.size(2) < int_max - block / warp_size &&
                      grad_output_.size(3) <= int_max - warp_size &&
                      grad_output_.size(4) <= int_max - warp_size,
                      "Output size is too large.");

          DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(1, 1)
          DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(2, 2)
          DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS
        }
    );
  }

  if (output_mask[2]) {
    grad_bias = grad_output.sum({0, 2, 3, 4});
  }

  return std::tie(grad_input, grad_weight, grad_bias);

}

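// Out-variant: resizes and zeroes grad_weight if it is defined, then forwards
// to the shared implementation with all three gradients requested.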
std::tuple<Tensor&, Tensor&, Tensor&> conv_depthwise3d_backward_cuda_out(const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  return _depthwise_3d_backward_cuda_out(
      grad_input,
      grad_weight,
      grad_bias,
      grad_output,
      input,
      weight,
      kernel_size,
      stride,
      padding,
      dilation,
      {true,true,true});
}

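// Functional variant: allocates grad_input / grad_weight only for the
// gradients requested in output_mask (grad_bias is produced inside the shared
// implementation by the sum) and returns them by value.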
std::tuple<Tensor, Tensor, Tensor> conv_depthwise3d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const std::array<bool, 3> output_mask) {

  auto options = grad_output.options();
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options) : Tensor());
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options) : Tensor());
  Tensor grad_bias; /* undefined temporarily */

  return _depthwise_3d_backward_cuda_out(
      grad_input,
      grad_weight,
      grad_bias,
      grad_output,
      input,
      weight,
      kernel_size,
      stride,
      padding,
      dilation,
      output_mask
  );

}

REGISTER_CUDA_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_cuda);

#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION
#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS

#undef NODEF_OR_EQUAL_3
#undef NODEF_OR_EQUAL

} // namespace at::native