#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/div_rtn.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/Resize.h>
#include <ATen/native/IndexingUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/_conv_depthwise2d_native.h>
#endif

namespace at::native {
namespace {
using at::cuda::detail::CUDA_NUM_THREADS;
using at::cuda::detail::GET_BLOCKS;

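// Returns an empty accessor (null data pointer, zero sizes and strides). It serves as a
// placeholder for the bias accessor when no bias tensor is defined, since the kernels
// below always take a bias argument; the biasEnabled flag guards against ever
// dereferencing it.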
template <typename scalar_t, int ndim, template <typename U> class PtrTraits = DefaultPtrTraits>
PackedTensorAccessor32<scalar_t, ndim, PtrTraits> dummy_packed_accessor32() {
  std::array<int64_t, ndim> zeros{};
  return {nullptr, zeros.data(), zeros.data()};
}

template <typename scalar_t, typename index_t>
__global__ void
#if !defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
#endif
conv_depthwise2d_forward_kernel_generic(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> output,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
    const PackedTensorAccessor32<const scalar_t, 1, DefaultPtrTraits> bias,
    bool biasEnabled,
    index_t totalElements,
    const int outputChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    // Calculate the n, c, h, w indices, replacing modulos with divides and multiply-adds;
    // the result is the same as the straightforward computation below:
    //   const int n = linearIndex / batchStride;                      // batchStride = outputChannels * outputHeight * outputWidth
    //   const int c = (linearIndex / channelStride) % outputChannels; // channelStride = outputHeight * outputWidth
    //   const int h = (linearIndex / outputWidth) % outputHeight;
    //   const int w = linearIndex % outputWidth;

    int indtmp1 = linearIndex/outputWidth;
    const int w = linearIndex - indtmp1 * outputWidth;
    int indtmp2 = indtmp1/outputHeight;
    const int h = indtmp1 - indtmp2 * outputHeight;
    indtmp1 = indtmp2;
    indtmp2 = indtmp1/outputChannels;
    const int c = indtmp1 - indtmp2 * outputChannels;
    const int n = indtmp2;

    int inputChannel = c;
    int inputChannels = outputChannels;
    if (depthwiseMultiplier != 1) {
      inputChannel /= depthwiseMultiplier;
      inputChannels /= depthwiseMultiplier;
    }

    int weightOffset = c * kernelHeight * kernelWidth;

    // By precisely computing the filtering boundaries, we avoid repeating several
    // expensive edge-condition checks for every fetched item. If the input element is
    // resident in L1, the extra branches and comparisons would cost roughly as many
    // cycles as the data fetch itself, so computing the boundaries ahead of the loop
    // gives a significant performance boost.
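    // A worked example of the boundary computation (illustration only, numbers chosen
    // arbitrarily): with padHeight = 2, h = 0, strideHeight = 1, dilationHeight = 2 we get
    // h_in_min = -2, so kHmin = ceil(2 / 2) = 1 and the kH = 0 tap (which would read row -2)
    // is skipped. Near the bottom, with inputHeight = 8, h_in_min = 6, kernelHeight = 3,
    // dilationHeight = 2, the last tap overshoots the valid range by h_in_max = 3, so kHmax
    // is reduced to 1 and only the kH = 0 tap (row 6) is read. In short,
    // kHmin = ceil(max(0, -h_in_min) / dilationHeight) and
    // kHmax = kernelHeight - ceil(max(0, h_in_max) / dilationHeight).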

    int kHmin = 0, kHmax = kernelHeight, kWmin = 0, kWmax = kernelWidth;

    // Top
    int h_in_min = -padHeight + h * strideHeight;
    if (h_in_min < 0) {
      kHmin = -h_in_min / dilationHeight;
      if ((-h_in_min) % dilationHeight > 0) {
        kHmin++;
      }
    }

    // Bottom
    int h_in_max = h_in_min + (kernelHeight - 1) * dilationHeight - inputHeight + 1;
    if (h_in_max >= 0) {
      kHmax = kernelHeight - h_in_max / dilationHeight;
      if (h_in_max % dilationHeight > 0) {
        kHmax--;
      }
    }

    // Left
    int w_in_min = -padWidth + w * strideWidth;
    if (w_in_min < 0) {
      kWmin = -w_in_min / dilationWidth;
      if ((-w_in_min) % dilationWidth > 0) {
        kWmin++;
      }
    }

    // Right
    int w_in_max = w_in_min + (kernelWidth - 1) * dilationWidth - inputWidth + 1;
    if (w_in_max >= 0) {
      kWmax = kernelWidth - w_in_max / dilationWidth;
      if (w_in_max % dilationWidth > 0) {
        kWmax--;
      }
    }
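    // With these bounds, every (kH, kW) visited by the loops below maps to an in-range
    // input coordinate, so the inner loop needs no per-element bounds checks.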

    acc_t value = biasEnabled ? static_cast<acc_t>(bias.data()[c]) : acc_t(0);
    const index_t offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;

    for (int kH = kHmin; kH < kHmax; ++kH) {
      const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
      for (int kW = kWmin; kW < kWmax; ++kW) {
        const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;
        const index_t offset = offset0 + h_in * inputWidth + w_in;
        value += (static_cast<acc_t>(weight.data()[weightOffset + kH * kernelWidth + kW]) *
                  static_cast<acc_t>(input.data()[offset]));
      }
    }
    output.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}

template <int kSize, typename scalar_t, typename index_t>
__global__ void
#if !defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
#endif
conv_depthwise2d_forward_kernel(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> output,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
    const PackedTensorAccessor32<const scalar_t, 1, DefaultPtrTraits> bias,
    bool biasEnabled,
    index_t totalElements,
    const int outputChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;
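  // When kSize is nonzero, KW_LIMIT and KH_LIMIT are compile-time constants, so the
  // #pragma unroll directives below can fully unroll the kernel-window loops (the forward
  // launcher instantiates kSize = 1, 3 and 5). With kSize == 0 the limits fall back to the
  // runtime kernel dimensions and the generic kernel above is used instead.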

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    // Calculate the n, c, h, w indices, replacing modulos with divides and multiply-adds;
    // the result is the same as the straightforward computation below:
    //   const int n = linearIndex / batchStride;                      // batchStride = outputChannels * outputHeight * outputWidth
    //   const int c = (linearIndex / channelStride) % outputChannels; // channelStride = outputHeight * outputWidth
    //   const int h = (linearIndex / outputWidth) % outputHeight;
    //   const int w = linearIndex % outputWidth;

    int indtmp1 = linearIndex/outputWidth;
    const int w = linearIndex - indtmp1 * outputWidth;
    int indtmp2 = indtmp1/outputHeight;
    const int h = indtmp1 - indtmp2 * outputHeight;
    indtmp1 = indtmp2;
    indtmp2 = indtmp1/outputChannels;
    const int c = indtmp1 - indtmp2 * outputChannels;
    const int n = indtmp2;

    int inputChannel = c;
    int inputChannels = outputChannels;
    if (depthwiseMultiplier != 1) {
      inputChannel /= depthwiseMultiplier;
      inputChannels /= depthwiseMultiplier;
    }

    int weightOffset = c * kernelHeight * kernelWidth;

    acc_t value = biasEnabled ? static_cast<acc_t>(bias.data()[c]) : acc_t(0);
    const index_t offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;
#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (int kH = 0; kH < KH_LIMIT; ++kH) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
      for (int kW = 0; kW < KW_LIMIT; ++kW) {
        const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
        const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;

        if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) {
          const index_t offset = offset0 + h_in * inputWidth + w_in;
          value += (static_cast<acc_t>(weight.data()[weightOffset]) *
                    static_cast<acc_t>(input.data()[offset]));
        }
        ++weightOffset;
      }
    }
    output.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}

template <int kSize, int stride, typename scalar_t, typename index_t>
#if !defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
#endif
__global__ void conv_depthwise2d_backward_kernel(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> grad_output,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_input,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
    index_t totalElements,
    const int inputChannels,
    const int depthwiseMultiplier,
    const int outputChannels,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;
  const int strideW = (stride != 0) ? stride : strideWidth;
  const int strideH = (stride != 0) ? stride : strideHeight;
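  // As in the forward kernel, nonzero kSize/stride template arguments turn the kernel
  // window size and the stride into compile-time constants; the launcher instantiates the
  // common cases (kernel 1/3/5, stride 1/2) and falls back to 0 (runtime values) otherwise.
  // A constant stride also lets the compiler simplify the % strideH / % strideW checks below.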

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    int indtmp1 = linearIndex/inputWidth;
    const int w = linearIndex - indtmp1 * inputWidth;
    int indtmp2 = indtmp1/inputHeight;
    const int h = indtmp1 - indtmp2 * inputHeight;
    indtmp1 = indtmp2;
    indtmp2 = indtmp1/inputChannels;
    const int c = indtmp1 - indtmp2 * inputChannels;
    const int n = indtmp2;

    acc_t value(0);

    for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
      int och = (c * depthwiseMultiplier) + multiplier;
      int weightOffset = och * kernelHeight * kernelWidth;
      for (int kh = 0; kh < KH_LIMIT; ++kh) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
        for (int kw = 0; kw < KW_LIMIT; ++kw) {
          int h_out = h + padHeight - kh * dilationHeight;
          int w_out = w + padWidth - kw * dilationWidth;
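          // A grad_output element contributes to this input position through tap (kh, kw)
          // only if the forward pass actually placed that tap here, i.e.
          // h + padHeight - kh * dilationHeight must be an exact multiple of the stride
          // (and likewise for the width); otherwise no output element read this input
          // through that tap.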
          if ((h_out % strideH == 0) && (w_out % strideW == 0)) {
            h_out = h_out / strideH;
            w_out = w_out / strideW;

            if ((h_out >= 0) && (h_out < outputHeight)
                && (w_out >= 0) && (w_out < outputWidth)) {

              const int offset = ((n * outputChannels + och) * outputHeight + h_out)
                  * outputWidth + w_out;
              value += (static_cast<acc_t>(weight.data()[weightOffset]) *
                        static_cast<acc_t>(grad_output.data()[offset]));
            }
          }
          ++weightOffset;
        }
      }
    }
    grad_input.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}

template <typename scalar_t, typename index_t=unsigned>
__global__ void conv_depthwise2d_grad_weight_kernel(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> grad_output,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_weight,
    const int batchSize,
    const int inputChannels,
    const int kernelChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;
  const int channelStride = kernelWidth * kernelHeight;

  // Each block is responsible for accumulating one element of the
  // (channels x kH x kW) filter; use blockIdx to determine which one
  int bidx = blockIdx.x;
  int kW = bidx % kernelWidth;
  int kH = (bidx / kernelWidth) % kernelHeight;
  int ch = (bidx / channelStride);
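  // The decomposition mirrors a row-major (ch, kH, kW) layout. For instance, with a 3x3
  // kernel, blockIdx.x = 14 decodes to kW = 14 % 3 = 2, kH = (14 / 3) % 3 = 1 and
  // ch = 14 / 9 = 1 (numbers chosen purely as an illustration).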

  // Need to calculate which input channel is associated with this filter
  // channel
  int inputCh = ch / depthwiseMultiplier;

  acc_t grad(0);

  const int laneId = threadIdx.x % C10_WARP_SIZE;
  const int batch = threadIdx.x / C10_WARP_SIZE;
  const int nwarps = blockDim.x / C10_WARP_SIZE;
  const int imageElements = outputWidth * outputHeight;
  // Use a warp per item. In the original kernel, a thread block was used to sum over NHW.
  // Here, we use a warp to sum values over the HW dimensions, and if batchSize is larger
  // than the number of warps, a warp loops over the remaining batch items (e.g. with 8
  // warps, warp 0 covers images 0, 8, 16, ..., warp 1 covers images 1, 9, 17, ...). All
  // warps are reduced later in BlockReduceSum anyway, so the full reduction is still over
  // NHW, as it should be. That gets rid of one modulo operation inside the loop (the batch
  // index no longer has to be computed through a modulo; we simply loop over it) and
  // brings a nice speed-up.
  for (int batchIdx = batch; batchIdx < batchSize; batchIdx += nwarps){
    // Warp-stride loop over elements in a batch item
    for (index_t idx = laneId; idx < imageElements; idx += C10_WARP_SIZE) {
      // Compute the offsets into grad_output in height and width; the corresponding
      // position in the input follows from the other parameters we have
      int go_w_offset = idx % outputWidth;
      int go_h_offset = (idx / outputWidth);

      int i_w_offset = (go_w_offset * strideWidth) + (kW * dilationWidth) - padWidth;
      int i_h_offset = (go_h_offset * strideHeight) + (kH * dilationHeight) - padHeight;

      if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < inputWidth && i_h_offset < inputHeight) {
        int inputOffset = ((batchIdx * inputChannels + inputCh) * inputHeight + i_h_offset) * inputWidth + i_w_offset;
        int outputOffset = ((batchIdx * kernelChannels + ch) * outputHeight) * outputWidth + idx;
        grad += (static_cast<acc_t>(input.data()[inputOffset]) *
                 static_cast<acc_t>(grad_output.data()[outputOffset]));
      }
    }
  }

  // At this point each thread in the block has a local gradient, which we need to
  // accumulate prior to writing the global value
  extern __shared__ char smem[];
  acc_t* buf = reinterpret_cast<acc_t*>(smem);
  acc_t tval = cuda_utils::BlockReduceSum(grad, buf);

  // After reduction, the first thread in the block has the gradient, so it is responsible
  // for writing it to grad_weight
  if (threadIdx.x == 0) {
    int weightOffset = kW + (kernelWidth * kH) + (kernelWidth * kernelHeight * ch);
    grad_weight.data()[weightOffset] = static_cast<scalar_t>(tval);
  }
}


void conv_depthwise2d_forward_out(
    const Tensor &input,
    const Tensor &output,
    const Tensor &weight,
    const Tensor &bias,
    const int kW, const int kH,
    const int dW, const int dH,
    const int padW, const int padH,
    const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4);
  TORCH_CHECK(output.is_contiguous());

  auto in_sizes = input.sizes();
  auto w_sizes = weight.sizes();

  // We assume that the input and weight Tensors are shaped properly by
  // the caller, so we verify that here to some extent

  // Weight Tensor is shape (output_channels, 1, kH, kW)
  TORCH_CHECK(w_sizes[1] == 1);

  // Input Tensor is shape (N, input_channels, H, W)
  // We verify that the # of output_channels is a multiple of input_channels
  TORCH_CHECK(w_sizes[0] % in_sizes[1] == 0);

  // Bias has same # of channels as output
  const bool has_bias = bias.defined();
  TORCH_CHECK(!has_bias || (bias.dim() <= 1 && bias.numel() == w_sizes[0]));

  // Following the behavior of other THCUNN functions, we shape the output
  // Tensor ourselves
  int64_t height = in_sizes[2];
  int64_t width = in_sizes[3];
  int64_t outputChannels = w_sizes[0];
  auto out_sizes = conv_output_size(in_sizes, weight.sizes(), {padH, padW}, {dH, dW},
                                    {dilationH, dilationW});
  const auto outputWidth = out_sizes[3];
  const auto outputHeight = out_sizes[2];
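  // For reference, this is the standard convolution output-size arithmetic that
  // conv_output_size performs:
  //   outputHeight = (height + 2 * padH - dilationH * (kH - 1) - 1) / dH + 1
  // and analogously for outputWidth.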

  resize_output(output, out_sizes);

  int64_t inputChannels = in_sizes[1];
  int64_t depthwiseMultiplier = outputChannels / inputChannels;

  // One thread per output value
  TORCH_CHECK(canUse32BitIndexMath(input) && canUse32BitIndexMath(output));
  int32_t n = output.numel();
  int blocks = GET_BLOCKS(n);
  dim3 grid(blocks);
  dim3 block(CUDA_NUM_THREADS);
  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
                                  "conv_depthwise2d_forward_cuda", [&] {
    // Create PackedTensorAccessors.
    // The kernel currently relies upon all the Tensors being contiguous; the caller
    // makes them contiguous before calling this function.
    const auto input_a = input.packed_accessor32<const scalar_t, 4>();
    const auto weight_a = weight.packed_accessor32<const scalar_t, 4>();
    const auto output_a = output.packed_accessor32<scalar_t, 4>();
    const auto bias_a = has_bias ?
        bias.packed_accessor32<const scalar_t, 1>() :
        dummy_packed_accessor32<const scalar_t, 1>();
    if (kW == 5 && kH == 5) {
      conv_depthwise2d_forward_kernel<5> <<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else if (kW == 3 && kH == 3) {
      conv_depthwise2d_forward_kernel<3> <<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else if (kW == 1 && kH == 1) {
      conv_depthwise2d_forward_kernel<1> <<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      conv_depthwise2d_forward_kernel_generic<<<grid, block, 0, stream>>>(
        input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
        width, height, outputWidth, outputHeight,
        kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}


void conv_depthwise2d_backward_out(
    const Tensor &input,
    const Tensor &grad_output,
    const Tensor &grad_input,
    const Tensor &weight,
    const int kW, const int kH,
    const int dW, const int dH,
    const int padW, const int padH,
    const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4);
  TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4);

  // Minimal shape checking, as above
  // Same # of elements in batch
  TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]);
  // Same # of filters as outputChannels
  TORCH_CHECK(weight.sizes()[0] == grad_output.sizes()[1]);

  // Resize grad_input
  auto in_sizes = input.sizes();
  resize_output(grad_input, in_sizes);

  int inputChannels = in_sizes[1];
  int height = in_sizes[2];
  int width = in_sizes[3];

  auto gO_sizes = grad_output.sizes();
  int outputChannels = gO_sizes[1];
  int outputHeight = gO_sizes[2];
  int outputWidth = gO_sizes[3];

  int depthwiseMultiplier = outputChannels / inputChannels;

  // Kernel currently relies upon all the Tensors to be contiguous
  TORCH_CHECK(grad_output.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(grad_input.is_contiguous());

  // One thread per grad_input value
  TORCH_CHECK(canUse32BitIndexMath(grad_input) &&
              canUse32BitIndexMath(grad_output));
  int32_t n = grad_input.numel();
  int blocks = GET_BLOCKS(n);
  dim3 grid(blocks);
  dim3 block(CUDA_NUM_THREADS);
  const auto stream = c10::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
                                  "conv_depthwise2d_backward_cuda", [&] {
    auto grad_output_a = grad_output.packed_accessor32<const scalar_t, 4>();
    auto grad_input_a = grad_input.packed_accessor32<scalar_t, 4>();
    auto weight_a = weight.packed_accessor32<const scalar_t, 4>();

    if (kW == 5 && kH == 5) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<5, 1><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<5, 2><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<5, 0><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (kW == 3 && kH == 3) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<3, 1><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<3, 2><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<3, 0><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (kW == 1 && kH == 1) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<1, 1><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<1, 2><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<1, 0><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (dW == 1 && dH == 1) {
      conv_depthwise2d_backward_kernel<0, 1><<<grid, block, 0, stream>>>(
        grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
        height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else if (dW == 2 && dH == 2) {
      conv_depthwise2d_backward_kernel<0, 2><<<grid, block, 0, stream>>>(
        grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
        height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      conv_depthwise2d_backward_kernel<0, 0><<<grid, block, 0, stream>>>(
        grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
        height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}

// Crude benchmarks suggest 256 is better than 512 and 1024
// TODO: Autotune/use better heuristics, improve speed more.
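// For example (illustration only): with a warp size of 32, a batch of 4 images gets
// 4 * 32 = 128 threads (4 warps, one per image), while a batch of 8 or more is capped at
// MAX_BLOCK_SIZE = 256 threads (8 warps), so each warp then loops over several images.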
int getGradParamsNumThreads(int batchSize) {
  // Warp per item in a batch, up to a maximum
  constexpr int MAX_BLOCK_SIZE = 256;
  return std::min(batchSize * at::cuda::warp_size(), MAX_BLOCK_SIZE);
}

void conv_depthwise2d_grad_weight_out(
    const Tensor &input,
    const Tensor &grad_output,
    const Tensor &grad_weight,
    const int kW, const int kH,
    const int dW, const int dH,
    const int padW, const int padH,
    const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4);

  // Minimal shape checking as above
  // Same # of elements in batch
  TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]);

  auto in_sizes = input.sizes();
  int batchSize = in_sizes[0];
  int inputChannels = in_sizes[1];
  int height = in_sizes[2];
  int width = in_sizes[3];

  auto gO_sizes = grad_output.sizes();
  int outputChannels = gO_sizes[1];
  int outputHeight = gO_sizes[2];
  int outputWidth = gO_sizes[3];

  int depthwiseMultiplier = outputChannels / inputChannels;

  resize_output(grad_weight, {outputChannels, 1, kH, kW});

  // Kernel currently relies upon all the Tensors to be contiguous
  TORCH_CHECK(grad_output.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(grad_weight.is_contiguous());

  // We parallelize so that each block computes a single value in grad_weight
  TORCH_CHECK(canUse32BitIndexMath(input) &&
              canUse32BitIndexMath(grad_output));
  int blocks = outputChannels * kH * kW;

  // Make sure we have enough threads to perform the reduction, and use this number
  // to create the shared memory size for the reduction
  dim3 grid(blocks);
  dim3 block(getGradParamsNumThreads(batchSize));
  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
                                  "conv_depthwise2d_grad_weight_cuda", [&] {
    const auto grad_output_a = grad_output.packed_accessor32<const scalar_t, 4>();
    const auto input_a = input.packed_accessor32<const scalar_t, 4>();
    const auto grad_weight_a = grad_weight.packed_accessor32<scalar_t, 4>();
    using acc_t = at::acc_type<scalar_t, true>;
    int warp_size = at::cuda::warp_size();
    TORCH_INTERNAL_ASSERT(block.x % warp_size == 0);
    int smem = (block.x / warp_size) * sizeof(acc_t);
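    // The shared-memory buffer holds one acc_t accumulator per warp: each warp first
    // reduces its values internally, writes a partial sum to the buffer, and those
    // per-warp partials are then combined by BlockReduceSum, so (block.x / warp_size)
    // slots suffice.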
    conv_depthwise2d_grad_weight_kernel<<<grid, block, smem, stream>>>(
      grad_output_a, input_a, grad_weight_a, batchSize, inputChannels, outputChannels, depthwiseMultiplier,
      width, height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace (anonymous)

const Tensor& conv_depthwise2d_cuda_out(
    const Tensor &input_,
    const Tensor &weight_,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const Tensor &out) {
  TORCH_CHECK(kernel_size.size() == 2);
  TORCH_CHECK(stride.size() == 2);
  TORCH_CHECK(padding.size() == 2);
  TORCH_CHECK(dilation.size() == 2);

  auto input = input_.expect_contiguous();
  auto weight = weight_.expect_contiguous();
  auto bias = [&] {
    if (bias_opt.has_value() && bias_opt->defined()) {
      return bias_opt->expect_contiguous();
    }
    return c10::MaybeOwned<Tensor>::owned(std::in_place);
  }();

  conv_depthwise2d_forward_out(
      *input,
      out,
      *weight,
      *bias,
      kernel_size[1], kernel_size[0],
      stride[1], stride[0],
      padding[1], padding[0],
      dilation[1], dilation[0]);
  return out;
}

Tensor conv_depthwise2d_cuda(
    const Tensor &input,
    const Tensor &weight,
    IntArrayRef kernel_size,
    const std::optional<Tensor> &bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  auto out = at::empty({0}, input.options());
  return conv_depthwise2d_cuda_out(input, weight, kernel_size, bias,
                                   stride, padding, dilation, out);
}

std::tuple<Tensor&, Tensor&> conv_depthwise2d_backward_cuda_out(
    const Tensor & grad_output_,
    const Tensor & self_,
    const Tensor & weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    Tensor & grad_input,
    Tensor & grad_weight) {
  auto grad_output = grad_output_.expect_contiguous();

  if (grad_weight.defined()) {
    auto self = self_.expect_contiguous();
    conv_depthwise2d_grad_weight_out(
        *self, *grad_output, grad_weight,
        kernel_size[1], kernel_size[0],
        stride[1], stride[0],
        padding[1], padding[0],
        dilation[1], dilation[0]);
  }

  if (grad_input.defined()) {
    auto weight = weight_.expect_contiguous();
    conv_depthwise2d_backward_out(
        self_, *grad_output, grad_input, *weight,
        kernel_size[1], kernel_size[0],
        stride[1], stride[0],
        padding[1], padding[0],
        dilation[1], dilation[0]);
  }
  return std::forward_as_tuple(grad_input, grad_weight);
}

std::tuple<Tensor, Tensor> conv_depthwise2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    std::array<bool, 2> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  }
  return conv_depthwise2d_backward_cuda_out(
      grad_output,
      self,
      weight,
      kernel_size,
      stride,
      padding,
      dilation,
      grad_input,
      grad_weight);
}

REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda);

} // namespace at::native