#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/UnfoldBackward.h>

#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <limits>
#include <vector>
// Note on naming: it is unconventional.
// grad_in does not mean a gradient w.r.t. the input;
// grad_in/grad_out are simply the input/output of the unfold_backward kernel.
//
// The unfold_backward algorithm is described in
// /native/cpu/UnfoldBackwardKernel.cpp
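//
// Recap of the forward op: input.unfold(dim, size, step) produces
// (input.size(dim) - size) / step + 1 folds along `dim` and appends a new
// last dimension of length `size`. Hence, in the backward pass, the element
// of grad_out at position idx_dim along `dim` receives contributions from
// every fold fold_idx with fold_idx * step <= idx_dim < fold_idx * step + size.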

namespace at::native {

namespace {

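// Work distribution for the elementwise kernel below: every block owns
// n_threads * n_elems_per_thread consecutive indices, and each thread walks
// through its share with a stride of n_threads. Illustrative example (values
// made up here; the real ones come from num_threads()/thread_work_size()):
// with n_threads = 4 and n_elems_per_thread = 2, block 0 / thread 1 handles
// indices 1 and 5.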
template <int n_threads, int n_elems_per_thread, typename func_t>
C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
__global__ void _unfold_backward_elementwise_kernel(int total_n_elems, func_t f) {
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  int idx = total_work_block * blockIdx.x + threadIdx.x;

  #pragma unroll
  for (int i = 0; i < n_elems_per_thread; ++i) {
    if (idx < total_n_elems) {
      f(idx);
      idx += n_threads;
    }
  }
}

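// Thin launcher: the grid is a ceil-division of total_n_elems by the
// per-block work, and the int32_t bound is asserted because the kernel
// indexes with int.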
template <int n_threads, int n_elems_per_thread, typename func_t>
static void _launch_unfold_backward_kernel(int total_n_elems, func_t f) {
  TORCH_INTERNAL_ASSERT(
    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
  );

  dim3 block(n_threads);
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);

  auto stream = at::cuda::getCurrentCUDAStream();
  _unfold_backward_elementwise_kernel<n_threads, n_elems_per_thread, func_t>
    <<<grid, block, 0, stream>>>(total_n_elems, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t>
void _unfold_backward_internal_kernel(
    TensorIterator& iter,
    int64_t size,
    int64_t step,
    int64_t grad_in_dim_stride,
    int64_t grad_in_last_dim_stride,
    int64_t grad_in_dim_size,
    int64_t grad_out_dim_stride
) {
  if (iter.numel() == 0) {
    return;
  }

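  // The device loop below indexes with 32-bit offsets, so a too-large iterator
  // is first split into 32-bit-indexable sub-iterators and handled recursively.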
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      _unfold_backward_internal_kernel<scalar_t>(
        sub_iter,
        size,
        step,
        grad_in_dim_stride,
        grad_in_last_dim_stride,
        grad_in_dim_size,
        grad_out_dim_stride
      );
    }
    return;
  }
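
  // Operand pointers are taken as raw bytes; the OffsetCalculator hands back
  // per-operand byte offsets for a given flattened iterator index.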
  char* __restrict__ grad_out_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* __restrict__ grad_in_ptr = reinterpret_cast<char*>(iter.data_ptr(1));
  char* __restrict__ idx_dim_ptr = reinterpret_cast<char*>(iter.data_ptr(2));

  auto offset_calc = make_offset_calculator<3>(iter);

  // The algorithm: for each element of grad_out, find the grad_in elements
  // that contribute to it and sum them up.
  // Note: the algorithm does not require any synchronization.
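  // The device lambda runs once per flattened iterator index i. The three
  // offsets locate the grad_out element being accumulated into, the grad_in
  // base address (fold offsets are applied manually below), and the stored
  // idx_dim, i.e. the element's position along the unfolded dimension.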
  auto loop = [=]C10_DEVICE(int i) {
    auto offsets = offset_calc.get(i);

    auto* __restrict__ grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr + offsets[0]);
    auto* __restrict__ grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr + offsets[1]);

    auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr + offsets[2]);

    // The left-most fold that can contain idx_dim is either
    // (idx_dim - size) / step or the next integer; start with the former and
    // bump it if idx_dim falls outside that fold's window.
    int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
    if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
      ++left_fold_idx;
    }

    // The right-most candidate fold starts at or before idx_dim; clamp it to
    // the last existing fold.
    auto right_fold_idx = idx_dim / step;
    right_fold_idx = (right_fold_idx >= grad_in_dim_size) ?
      (grad_in_dim_size - 1) : right_fold_idx;
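
    // Worked example (illustrative numbers, not from the call site):
    // size = 3, step = 2, idx_dim = 4, grad_in_dim_size = 3.
    // left_fold_idx: (4 - 3) / 2 = 0, but fold 0 only covers [0, 3), so it
    // becomes 1. right_fold_idx: 4 / 2 = 2. Folds 1 and 2 both cover input
    // position 4, so their grad_in entries at
    // idx_last_dim = 4 - fold_idx * step are summed below.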

    for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
      auto idx_last_dim = idx_dim - fold_idx * step;
      *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
                                   + idx_last_dim * grad_in_last_dim_stride];
    }
  };

  _launch_unfold_backward_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}

void unfold_backward_cuda_kernel(
    Tensor& grad_out,
    const Tensor& grad_in,
    int64_t dim,
    int64_t size,
    int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // last dim stores the folds
  auto last_dim = maybe_wrap_dim(-1, grad_in.dim());

  auto grad_in_dim_stride = ensure_nonempty_stride(grad_in, dim);
  auto grad_in_last_dim_stride = ensure_nonempty_stride(grad_in, last_dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);

  auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim);

  TensorIterator iter = _make_unfold_backward_iter_over_grad_out(
      grad_out, grad_in, dim, size, step);
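
  // The iterator pairs every grad_out element with the matching grad_in slice
  // and with its index along `dim` (read as idx_dim inside the device loop);
  // the dispatch below instantiates the kernel for every supported dtype,
  // including Half, Bool and BFloat16.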

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
    iter.dtype(),
    "unfold_backward_cuda", [&] {
      _unfold_backward_internal_kernel<scalar_t>(
        iter,
        size,
        step,
        grad_in_dim_stride,
        grad_in_last_dim_stride,
        grad_in_dim_size,
        grad_out_dim_stride
      );
    }
  );
}

} // anonymous namespace

REGISTER_DISPATCH(unfold_backward_stub, &unfold_backward_cuda_kernel);

} // namespace at::native