#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/UnfoldBackward.h>

#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <limits>
#include <vector>

// Note on naming: it is unconventional.
// grad_in does not mean a gradient with respect to the input;
// grad_in/grad_out are simply the input/output of the unfold_backward kernel.
//
// The unfold_backward algorithm is described in
// /native/cpu/UnfoldBackwardKernel.cpp
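//
// For intuition: given x = [x0, x1, x2, x3, x4], the view
// x.unfold(0, /*size=*/3, /*step=*/2) produces the folds
// [[x0, x1, x2], [x2, x3, x4]]. In the backward pass each element of
// grad_out accumulates the grad_in entries of every fold that covers it;
// e.g. x2 is covered by both folds and therefore receives two contributions.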

namespace at::native {

namespace {

template <int n_threads, int n_elems_per_thread, typename func_t>
C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
__global__ void _unfold_backward_elementwise_kernel(int total_n_elems, func_t f) {
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  int idx = total_work_block * blockIdx.x + threadIdx.x;

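  // Each block handles a contiguous chunk of total_work_block elements;
  // within that chunk, each thread processes n_elems_per_thread elements
  // spaced n_threads apart, so consecutive threads touch consecutive indices.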
  #pragma unroll
  for (int i = 0; i < n_elems_per_thread; ++i) {
    if (idx < total_n_elems) {
      f(idx);
      idx += n_threads;
    }
  }
}

template <int n_threads, int n_elems_per_thread, typename func_t>
static void _launch_unfold_backward_kernel(int total_n_elems, func_t f) {
  TORCH_INTERNAL_ASSERT(
    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
  );

  dim3 block(n_threads);
  constexpr int total_work_block = n_threads * n_elems_per_thread;
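  // Grid size is the ceiling of total_n_elems / total_work_block, i.e. enough
  // blocks to cover every element.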
  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);

  auto stream = at::cuda::getCurrentCUDAStream();
  _unfold_backward_elementwise_kernel<n_threads, n_elems_per_thread, func_t>
    <<<grid, block, 0, stream>>>(total_n_elems, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t>
void _unfold_backward_internal_kernel(
  TensorIterator& iter,
  int64_t size,
  int64_t step,
  int64_t grad_in_dim_stride,
  int64_t grad_in_last_dim_stride,
  int64_t grad_in_dim_size,
  int64_t grad_out_dim_stride
) {
  if (iter.numel() == 0) {
    return;
  }

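  // Both the launch helper (int element counts) and the offset calculator
  // assume 32-bit indexing, so iterators that are too large are split into
  // 32-bit-indexable sub-iterators and handled recursively.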
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      _unfold_backward_internal_kernel<scalar_t>(
        sub_iter,
        size,
        step,
        grad_in_dim_stride,
        grad_in_last_dim_stride,
        grad_in_dim_size,
        grad_out_dim_stride
      );
    }
    return;
  }

  char* __restrict__ grad_out_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* __restrict__ grad_in_ptr = reinterpret_cast<char*>(iter.data_ptr(1));
  char* __restrict__ idx_dim_ptr = reinterpret_cast<char*>(iter.data_ptr(2));

  auto offset_calc = make_offset_calculator<3>(iter);

  // The algorithm is: for each index in grad_out find
  // the elements contributing to it and sum them up.
  // Note: the algorithm does not require any synchronization.
  auto loop = [=]C10_DEVICE(int i) {
    auto offsets = offset_calc.get(i);

    auto* __restrict__ grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr + offsets[0]);
    auto* __restrict__ grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr + offsets[1]);

    auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr + offsets[2]);

    // The leftmost fold that can intersect idx_dim
    // is either (idx_dim - size) / step or the next integer.
    int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
    if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
      ++left_fold_idx;
    }

    auto right_fold_idx = idx_dim / step;
    right_fold_idx = (right_fold_idx >= grad_in_dim_size) ?
      (grad_in_dim_size - 1) : right_fold_idx;
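    // Worked example: with size = 3, step = 2, grad_in_dim_size = 2 and
    // idx_dim = 2, left_fold_idx starts at 0 and stays 0 (0 * 2 <= 2 < 0 * 2 + 3),
    // while right_fold_idx = min(2 / 2, grad_in_dim_size - 1) = 1, so folds 0 and 1
    // both contribute to this grad_out element.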

    for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
      auto idx_last_dim = idx_dim - fold_idx * step;
      *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
                                  + idx_last_dim * grad_in_last_dim_stride];
    }

  };

  _launch_unfold_backward_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}

void unfold_backward_cuda_kernel(
  Tensor& grad_out,
  const Tensor& grad_in,
  int64_t dim,
  int64_t size,
  int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // last dim stores the folds
  auto last_dim = maybe_wrap_dim(-1, grad_in.dim());

  auto grad_in_dim_stride = ensure_nonempty_stride(grad_in, dim);
  auto grad_in_last_dim_stride = ensure_nonempty_stride(grad_in, last_dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);

  auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim);

  TensorIterator iter = _make_unfold_backward_iter_over_grad_out(
      grad_out, grad_in, dim, size, step);

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
    iter.dtype(),
    "unfold_backward_cuda", [&] {
      _unfold_backward_internal_kernel<scalar_t>(
        iter,
        size,
        step,
        grad_in_dim_stride,
        grad_in_last_dim_stride,
        grad_in_dim_size,
        grad_out_dim_stride
      );
    }
  );
}

} // anonymous namespace

REGISTER_DISPATCH(unfold_backward_stub, &unfold_backward_cuda_kernel);

} // namespace at::native