#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/div_rtn.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/Resize.h>
#include <ATen/native/IndexingUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/_conv_depthwise2d_native.h>
#endif

namespace at::native {
namespace {
using at::cuda::detail::CUDA_NUM_THREADS;
using at::cuda::detail::GET_BLOCKS;

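// Builds a placeholder accessor (null data pointer, zero sizes/strides) so an optional
// tensor such as bias can still be passed to the kernels by value when it is not defined;
// the kernels never dereference it in that case.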
template <typename scalar_t, int ndim, template <typename U> class PtrTraits = DefaultPtrTraits>
PackedTensorAccessor32<scalar_t, ndim, PtrTraits> dummy_packed_accessor32() {
  std::array<int64_t, ndim> zeros{};
  return {nullptr, zeros.data(), zeros.data()};
}

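// Generic depthwise forward kernel: one thread per output element (N, C_out, H_out, W_out),
// with the valid filter-tap range [kHmin, kHmax) x [kWmin, kWmax) computed up front so the
// inner loops need no per-tap bounds checks.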
template <typename scalar_t, typename index_t>
__global__ void
#if !defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
#endif
conv_depthwise2d_forward_kernel_generic(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> output,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
    const PackedTensorAccessor32<const scalar_t, 1, DefaultPtrTraits> bias,
    bool biasEnabled,
    index_t totalElements,
    const int outputChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    // Calculate the n,c,h,w indices, replacing modulos by divide and multiply-add;
    // the result is the same as it would be with the code below:
    //const int n = linearIndex / batchStride; //batchStride = outputChannels * outputHeight * outputWidth
    //const int c = (linearIndex / channelStride) % outputChannels; //channelStride = outputHeight * outputWidth
    //const int h = (linearIndex / outputWidth) % outputHeight;
    //const int w = linearIndex % outputWidth;

    int indtmp1 = linearIndex/outputWidth;
    const int w = linearIndex - indtmp1 * outputWidth;
    int indtmp2 = indtmp1/outputHeight;
    const int h = indtmp1 - indtmp2 * outputHeight;
    indtmp1 = indtmp2;
    indtmp2 = indtmp1/outputChannels;
    const int c = indtmp1 - indtmp2 * outputChannels;
    const int n = indtmp2;

    int inputChannel = c;
    int inputChannels = outputChannels;
    if (depthwiseMultiplier != 1) {
      inputChannel /= depthwiseMultiplier;
      inputChannels /= depthwiseMultiplier;
    }

    int weightOffset = c * kernelHeight * kernelWidth;

    // By precisely computing the filtering boundaries, we avoid repeating several
    // expensive edge-condition checks for every fetched item. If the input element is
    // resident in L1, the extra branches and comparisons would be comparable in cost
    // (in cycles) to the actual data fetch, so computing the boundaries ahead of the
    // loop showed a significant performance boost.
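    //
    // For example, with padHeight = 1, strideHeight = 1, dilationHeight = 1 and h = 0,
    // h_in_min = -1, so kHmin becomes 1 and the kH = 0 tap (which would read input row -1)
    // is skipped without a per-tap check. Symmetrically, h_in_max measures how far the last
    // tap overshoots the bottom edge; when it is >= 0, kHmax is reduced by
    // ceil(h_in_max / dilationHeight).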

    int kHmin = 0, kHmax = kernelHeight, kWmin = 0, kWmax = kernelWidth;

    // Top
    int h_in_min = -padHeight + h * strideHeight;
    if (h_in_min < 0) {
      kHmin = -h_in_min / dilationHeight;
      if ((-h_in_min) % dilationHeight > 0) {
        kHmin++;
      }
    }

    // Bottom
    int h_in_max = h_in_min + (kernelHeight - 1) * dilationHeight - inputHeight + 1;
    if (h_in_max >= 0) {
      kHmax = kernelHeight - h_in_max / dilationHeight;
      if (h_in_max % dilationHeight > 0) {
        kHmax--;
      }
    }

    // Left
    int w_in_min = -padWidth + w * strideWidth;
    if (w_in_min < 0) {
      kWmin = -w_in_min / dilationWidth;
      if ((-w_in_min) % dilationWidth > 0) {
        kWmin++;
      }
    }

    // Right
    int w_in_max = w_in_min + (kernelWidth - 1) * dilationWidth - inputWidth + 1;
    if (w_in_max >= 0) {
      kWmax = kernelWidth - w_in_max / dilationWidth;
      if (w_in_max % dilationWidth > 0) {
        kWmax--;
      }
    }

    acc_t value = biasEnabled ? static_cast<acc_t>(bias.data()[c]) : acc_t(0);
    const index_t offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;

    for (int kH = kHmin; kH < kHmax; ++kH) {
      const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
      for (int kW = kWmin; kW < kWmax; ++kW) {
        const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;
        const index_t offset = offset0 + h_in * inputWidth + w_in;
        value += (static_cast<acc_t>(weight.data()[weightOffset + kH * kernelWidth + kW]) *
                    static_cast<acc_t>(input.data()[offset]));
      }
    }
    output.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}

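// Specialized depthwise forward kernel: when kSize != 0 the kernel extent is a compile-time
// constant (1, 3, or 5 below), letting the compiler fully unroll the tap loops; kSize == 0
// falls back to the runtime kernelWidth/kernelHeight with per-tap bounds checks.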
template <int kSize, typename scalar_t, typename index_t>
__global__ void
#if !defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
#endif
conv_depthwise2d_forward_kernel(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> output,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
    const PackedTensorAccessor32<const scalar_t, 1, DefaultPtrTraits> bias,
    bool biasEnabled,
    index_t totalElements,
    const int outputChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    // Calculate the n,c,h,w indices, replacing modulos by divide and multiply-add;
    // the result is the same as it would be with the code below:
    //const int n = linearIndex / batchStride; //batchStride = outputChannels * outputHeight * outputWidth
    //const int c = (linearIndex / channelStride) % outputChannels; //channelStride = outputHeight * outputWidth
    //const int h = (linearIndex / outputWidth) % outputHeight;
    //const int w = linearIndex % outputWidth;

    int indtmp1 = linearIndex/outputWidth;
    const int w = linearIndex - indtmp1 * outputWidth;
    int indtmp2 = indtmp1/outputHeight;
    const int h = indtmp1 - indtmp2 * outputHeight;
    indtmp1 = indtmp2;
    indtmp2 = indtmp1/outputChannels;
    const int c = indtmp1 - indtmp2 * outputChannels;
    const int n = indtmp2;

    int inputChannel = c;
    int inputChannels = outputChannels;
    if (depthwiseMultiplier != 1) {
      inputChannel /= depthwiseMultiplier;
      inputChannels /= depthwiseMultiplier;
    }

    int weightOffset = c * kernelHeight * kernelWidth;

    acc_t value = biasEnabled ? static_cast<acc_t>(bias.data()[c]) : acc_t(0);
    const index_t offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;
#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (int kH = 0; kH < KH_LIMIT; ++kH) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
      for (int kW = 0; kW < KW_LIMIT; ++kW) {
        const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
        const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;

        if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) {
          const index_t offset = offset0 + h_in * inputWidth + w_in;
          value += (static_cast<acc_t>(weight.data()[weightOffset]) *
                    static_cast<acc_t>(input.data()[offset]));
        }
        ++weightOffset;
      }
    }
    output.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}

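// Backward (grad_input) kernel: one thread per input element. Each thread gathers the
// grad_output positions that the corresponding filter taps map back to, transposed-convolution
// style, skipping taps whose output coordinate is not an exact multiple of the stride or falls
// outside the output. The kSize/stride template parameters enable unrolling for the common
// cases; 0 means "use the runtime value".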
template <int kSize, int stride, typename scalar_t, typename index_t>
#if !defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
#endif
__global__ void conv_depthwise2d_backward_kernel(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> grad_output,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_input,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
    index_t totalElements,
    const int inputChannels,
    const int depthwiseMultiplier,
    const int outputChannels,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;
  const int strideW = (stride != 0) ? stride : strideWidth;
  const int strideH = (stride != 0) ? stride : strideHeight;

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    int indtmp1 = linearIndex/inputWidth;
    const int w = linearIndex - indtmp1 * inputWidth;
    int indtmp2 = indtmp1/inputHeight;
    const int h = indtmp1 - indtmp2 * inputHeight;
    indtmp1 = indtmp2;
    indtmp2 = indtmp1/inputChannels;
    const int c = indtmp1 - indtmp2 * inputChannels;
    const int n = indtmp2;

    acc_t value(0);

    for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
      int och = (c * depthwiseMultiplier) + multiplier;
      int weightOffset = och * kernelHeight * kernelWidth;
      for (int kh = 0; kh < KH_LIMIT; ++kh) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
        for (int kw = 0; kw < KW_LIMIT; ++kw) {
          int h_out = h + padHeight - kh * dilationHeight;
          int w_out = w + padWidth - kw * dilationWidth;
          if ((h_out % strideH == 0) && (w_out % strideW == 0)) {
            h_out = h_out / strideH;
            w_out = w_out / strideW;

            if ((h_out >= 0) && (h_out < outputHeight)
                  && (w_out >= 0) && (w_out < outputWidth)) {

              const int offset = ((n * outputChannels + och) * outputHeight + h_out)
                    * outputWidth + w_out;
              value += (static_cast<acc_t>(weight.data()[weightOffset]) *
                        static_cast<acc_t>(grad_output.data()[offset]));
            }
          }
          ++weightOffset;
        }
      }
    }
    grad_input.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}

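// Weight-gradient kernel: each block owns a single (output channel, kH, kW) entry of
// grad_weight and reduces input * grad_output products over the whole N x H_out x W_out
// volume; the per-thread partial sums are combined with a block reduction in shared memory
// before one thread writes the result.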
template <typename scalar_t, typename index_t=unsigned>
__global__ void conv_depthwise2d_grad_weight_kernel(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> grad_output,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_weight,
    const int batchSize,
    const int inputChannels,
    const int kernelChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;
  const int channelStride = kernelWidth * kernelHeight;

  // Each block is responsible for accumulating over one (channel, kH, kW) combination;
  // use blockIdx to determine which one.
  int bidx = blockIdx.x;
  int kW = bidx % kernelWidth;
  int kH = (bidx / kernelWidth) % kernelHeight;
  int ch = (bidx / channelStride);

  // Need to calculate which input channel is associated with this filter
  // channel
  int inputCh = ch / depthwiseMultiplier;

  acc_t grad(0);

  const int laneId = threadIdx.x % C10_WARP_SIZE;
  const int batch = threadIdx.x / C10_WARP_SIZE;
  const int nwarps = blockDim.x / C10_WARP_SIZE;
  const int imageElements = outputWidth * outputHeight;
  // Use one warp per batch item. In the original kernel, a threadblock was used to sum over
  // NHW. Here, we use a warp to sum values over the HW dimensions, and if batchSize is larger
  // than the number of warps, a warp loops over the remaining batch items (e.g. with 8 warps,
  // warp 0 handles images 0, 8, 16, ..., warp 1 handles 1, 9, 17, ...). Later, in the block
  // reduction, all the warps are reduced anyway, so the full reduction is still over NHW, as
  // it should be. This gets rid of one modulo operation inside the loop (because n/batchIdx
  // no longer has to be computed through a modulo; we just loop over it), which brings a nice
  // speed-up.
  for (int batchIdx = batch; batchIdx < batchSize; batchIdx += nwarps){
    // Warp-stride loop over elements in a batch item
    for (index_t idx = laneId; idx < imageElements; idx += C10_WARP_SIZE) {
      // Need to calculate the batch position and the offsets into grad_output in height and
      // width; the corresponding position in the input follows from the other parameters we have.
      int go_w_offset = idx % outputWidth;
      int go_h_offset = (idx / outputWidth);

      int i_w_offset = (go_w_offset * strideWidth) + (kW * dilationWidth) - padWidth;
      int i_h_offset = (go_h_offset * strideHeight) + (kH * dilationHeight) - padHeight;

      if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < inputWidth && i_h_offset < inputHeight) {
        int inputOffset = ((batchIdx * inputChannels + inputCh) * inputHeight + i_h_offset) * inputWidth + i_w_offset;
        int outputOffset = ((batchIdx * kernelChannels + ch) * outputHeight) * outputWidth + idx;
        grad += (static_cast<acc_t>(input.data()[inputOffset]) *
                 static_cast<acc_t>(grad_output.data()[outputOffset]));
      }
    }
  }

  // At this point each thread in the block has a local gradient, which we need to
  // accumulate prior to writing the global value
  extern __shared__ char smem[];
  acc_t* buf = reinterpret_cast<acc_t*>(smem);
  acc_t tval = cuda_utils::BlockReduceSum(grad, buf);

  // After the reduction, the first thread in the block has the gradient, so it is responsible
  // for writing it to grad_weight
  if (threadIdx.x == 0) {
    int weightOffset = kW + (kernelWidth * kH) + (kernelWidth * kernelHeight * ch);
    grad_weight.data()[weightOffset] = static_cast<scalar_t>(tval);
  }
}

void conv_depthwise2d_forward_out(
                  const Tensor &input,
                  const Tensor &output,
                  const Tensor &weight,
                  const Tensor &bias,
                  const int kW, const int kH,
                  const int dW, const int dH,
                  const int padW, const int padH,
                  const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4);
  TORCH_CHECK(output.is_contiguous());

  auto in_sizes = input.sizes();
  auto w_sizes = weight.sizes();

  // We assume that the input and weight Tensors are shaped properly by
  // the caller, so we verify that here to some extent

  // Weight Tensor is shape (output_channels, 1, kH, kW)
  TORCH_CHECK(w_sizes[1] == 1);

  // Input Tensor is shape (N, input_channels, H, W)
  // We verify that the # of output_channels is a multiple of input_channels
  TORCH_CHECK(w_sizes[0] % in_sizes[1] == 0);

  // Bias has same # of channels as output
  const bool has_bias = bias.defined();
  TORCH_CHECK(!has_bias || (bias.dim() <= 1 && bias.numel() == w_sizes[0]));

  // Following the behavior of other THCUNN functions, we shape the output
  // Tensor ourselves
  int64_t height = in_sizes[2];
  int64_t width = in_sizes[3];
  int64_t outputChannels = w_sizes[0];
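  // Per spatial dimension, conv_output_size applies the standard shape formula:
  //   out = (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1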
  auto out_sizes = conv_output_size(in_sizes, weight.sizes(), {padH, padW}, {dH, dW},
                                    {dilationH, dilationW});
  const auto outputWidth = out_sizes[3];
  const auto outputHeight = out_sizes[2];

  resize_output(output, out_sizes);

  int64_t inputChannels = in_sizes[1];
  int64_t depthwiseMultiplier = outputChannels / inputChannels;

  // One thread per output value
  TORCH_CHECK(canUse32BitIndexMath(input) && canUse32BitIndexMath(output));
  int32_t n = output.numel();
  int blocks = GET_BLOCKS(n);
  dim3 grid(blocks);
  dim3 block(CUDA_NUM_THREADS);
  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                  "conv_depthwise2d_forward_cuda", [&] {
    // Create PackedTensorAccessors.
    // The kernel currently relies upon all the Tensors being contiguous; the caller has
    // already made them contiguous.
    const auto input_a = input.packed_accessor32<const scalar_t, 4>();
    const auto weight_a = weight.packed_accessor32<const scalar_t, 4>();
    const auto output_a = output.packed_accessor32<scalar_t, 4>();
    const auto bias_a = has_bias ?
      bias.packed_accessor32<const scalar_t, 1>() :
      dummy_packed_accessor32<const scalar_t, 1>();
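    // Dispatch a fully unrolled kernel for the common square filter sizes; anything else
    // goes to the generic kernel with precomputed tap bounds.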
    if (kW == 5 && kH == 5) {
      conv_depthwise2d_forward_kernel<5> <<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else if (kW == 3 && kH == 3) {
      conv_depthwise2d_forward_kernel<3> <<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else if (kW == 1 && kH == 1) {
      conv_depthwise2d_forward_kernel<1> <<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      conv_depthwise2d_forward_kernel_generic<<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}

void conv_depthwise2d_backward_out(
                  const Tensor &input,
                  const Tensor &grad_output,
                  const Tensor &grad_input,
                  const Tensor &weight,
                  const int kW, const int kH,
                  const int dW, const int dH,
                  const int padW, const int padH,
                  const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4);
  TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4);

  // Minimal shape checking, as above
  // Same # of elements in batch
  TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]);
  // Same # of filters as outputChannels
  TORCH_CHECK(weight.sizes()[0] == grad_output.sizes()[1]);

  // Resize grad_input
  auto in_sizes = input.sizes();
  resize_output(grad_input, in_sizes);

  int inputChannels = in_sizes[1];
  int height = in_sizes[2];
  int width = in_sizes[3];

  auto gO_sizes = grad_output.sizes();
  int outputChannels = gO_sizes[1];
  int outputHeight = gO_sizes[2];
  int outputWidth = gO_sizes[3];

  int depthwiseMultiplier = outputChannels / inputChannels;

  // Kernel currently relies upon all the Tensors to be contiguous
  TORCH_CHECK(grad_output.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(grad_input.is_contiguous());

  // One thread per grad_input value
  TORCH_CHECK(canUse32BitIndexMath(grad_input) &&
              canUse32BitIndexMath(grad_output));
  int32_t n = grad_input.numel();
  int blocks = GET_BLOCKS(n);
  dim3 grid(blocks);
  dim3 block(CUDA_NUM_THREADS);
  const auto stream = c10::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
                                  "conv_depthwise2d_backward_cuda", [&] {
    auto grad_output_a = grad_output.packed_accessor32<const scalar_t, 4>();
    auto grad_input_a = grad_input.packed_accessor32<scalar_t, 4>();
    auto weight_a = weight.packed_accessor32<const scalar_t, 4>();

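    // Specialize on (kernel size, stride) so the tap loops unroll and the stride
    // divisions/modulos operate on compile-time constants; <0, 0> is the fully generic fallback.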
    if (kW == 5 && kH == 5) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<5, 1><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<5, 2><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<5, 0><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (kW == 3 && kH == 3) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<3, 1><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<3, 2><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<3, 0><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (kW == 1 && kH == 1) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<1, 1><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<1, 2><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<1, 0><<<grid, block, 0, stream>>>(
            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (dW == 1 && dH == 1) {
      conv_depthwise2d_backward_kernel<0, 1><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else if (dW == 2 && dH == 2) {
      conv_depthwise2d_backward_kernel<0, 2><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      conv_depthwise2d_backward_kernel<0, 0><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}

// Crude benchmarks suggest 256 is better than 512 and 1024
// TODO: Autotune/use better heuristics, improve speed more.
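// e.g. with a 32-thread warp, batchSize = 4 gives 128 threads; batchSize >= 8 is capped at 256.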
int getGradParamsNumThreads(int batchSize) {
  // one warp per batch item, up to a maximum
  constexpr int MAX_BLOCK_SIZE = 256;
  return std::min(batchSize * at::cuda::warp_size(), MAX_BLOCK_SIZE);
}

void conv_depthwise2d_grad_weight_out(
                  const Tensor &input,
                  const Tensor &grad_output,
                  const Tensor &grad_weight,
                  const int kW, const int kH,
                  const int dW, const int dH,
                  const int padW, const int padH,
                  const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4);

  // Minimal shape checking, as above
  // Same # of elements in batch
  TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]);

  auto in_sizes = input.sizes();
  int batchSize = in_sizes[0];
  int inputChannels = in_sizes[1];
  int height = in_sizes[2];
  int width = in_sizes[3];

  auto gO_sizes = grad_output.sizes();
  int outputChannels = gO_sizes[1];
  int outputHeight = gO_sizes[2];
  int outputWidth = gO_sizes[3];

  int depthwiseMultiplier = outputChannels / inputChannels;

  resize_output(grad_weight, {outputChannels, 1, kH, kW});

  // Kernel currently relies upon all the Tensors to be contiguous
  TORCH_CHECK(grad_output.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(grad_weight.is_contiguous());

  // We parallelize so that each block computes a single value in grad_weight
  TORCH_CHECK(canUse32BitIndexMath(input) &&
              canUse32BitIndexMath(grad_output));
  int blocks = outputChannels * kH * kW;

  // Make sure we have enough threads to perform the reduction, and use this number
  // to create the shared memory size for the reduction
  dim3 grid(blocks);
  dim3 block(getGradParamsNumThreads(batchSize));
  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
                                  "conv_depthwise2d_grad_weight_cuda", [&] {
    const auto grad_output_a = grad_output.packed_accessor32<const scalar_t, 4>();
    const auto input_a = input.packed_accessor32<const scalar_t, 4>();
    const auto grad_weight_a = grad_weight.packed_accessor32<scalar_t, 4>();
    using acc_t = at::acc_type<scalar_t, true>;
    int warp_size = at::cuda::warp_size();
    TORCH_INTERNAL_ASSERT(block.x % warp_size == 0);
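    // One shared-memory accumulator slot per warp for the block reduction.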
    int smem = (block.x / warp_size) * sizeof(acc_t);
    conv_depthwise2d_grad_weight_kernel<<<grid, block, smem, stream>>>(
        grad_output_a, input_a, grad_weight_a, batchSize, inputChannels, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

}  // namespace (anonymous)

const Tensor& conv_depthwise2d_cuda_out(
    const Tensor &input_,
    const Tensor &weight_,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const Tensor &out) {
  TORCH_CHECK(kernel_size.size() == 2);
  TORCH_CHECK(stride.size() == 2);
  TORCH_CHECK(padding.size() == 2);
  TORCH_CHECK(dilation.size() == 2);

  auto input = input_.expect_contiguous();
  auto weight = weight_.expect_contiguous();
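  // When no bias is given, fall back to an owned, default-constructed (undefined) Tensor;
  // conv_depthwise2d_forward_out treats an undefined bias as "no bias".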
  auto bias = [&] {
    if (bias_opt.has_value() && bias_opt->defined()) {
      return bias_opt->expect_contiguous();
    }
    return c10::MaybeOwned<Tensor>::owned(std::in_place);
  }();

  conv_depthwise2d_forward_out(
      *input,
      out,
      *weight,
      *bias,
      kernel_size[1], kernel_size[0],
      stride[1], stride[0],
      padding[1], padding[0],
      dilation[1], dilation[0]);
  return out;
}

Tensor conv_depthwise2d_cuda(
    const Tensor &input,
    const Tensor &weight,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  auto out = at::empty({0}, input.options());
  return conv_depthwise2d_cuda_out(input, weight, kernel_size, bias,
                                   stride, padding, dilation, out);
}

std::tuple<Tensor&, Tensor&> conv_depthwise2d_backward_cuda_out(
    const Tensor & grad_output_,
    const Tensor & self_,
    const Tensor & weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    Tensor & grad_input,
    Tensor & grad_weight) {
  auto grad_output = grad_output_.expect_contiguous();

  if (grad_weight.defined()) {
    auto self = self_.expect_contiguous();
    conv_depthwise2d_grad_weight_out(
        *self, *grad_output, grad_weight,
        kernel_size[1], kernel_size[0],
        stride[1], stride[0],
        padding[1], padding[0],
        dilation[1], dilation[0]);
  }

  if (grad_input.defined()) {
    auto weight = weight_.expect_contiguous();
    conv_depthwise2d_backward_out(
        self_, *grad_output, grad_input, *weight,
        kernel_size[1], kernel_size[0],
        stride[1], stride[0],
        padding[1], padding[0],
        dilation[1], dilation[0]);
  }
  return std::forward_as_tuple(grad_input, grad_weight);
}

std::tuple<Tensor, Tensor> conv_depthwise2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    std::array<bool, 2> output_mask) {
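  // output_mask[0] / output_mask[1] select whether grad_input / grad_weight are computed;
  // tensors left undefined are skipped by the _out variant above.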
  Tensor grad_input;
  Tensor grad_weight;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  }
  return conv_depthwise2d_backward_cuda_out(
      grad_output,
      self,
      weight,
      kernel_size,
      stride,
      padding,
      dilation,
      grad_input,
      grad_weight);
}

REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda);

} // namespace at::native