#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/UnfoldBackward.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>

#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

// Note on naming: the naming here is unconventional.
// grad_in does not mean a gradient w.r.t. the input;
// grad_in/grad_out are simply the input/output of the unfold_backward kernel.
//
// unfold_backward, the algorithm.
//
// Consider out = in.unfold(dim, size, step); then
// out.shape[dim] == (in.shape[dim] - size) / step + 1,
// out.shape[-1] == size,
// out.dim() == in.dim() + 1.
//
// unfold_backward receives grad_in and returns grad_out such that
// grad_in.shape == out.shape,
// grad_out.shape == in.shape.
//
// unfold_backward considers the following two cases:
// case1. step >= size.
// case2. step < size.
//
// case1. step >= size.
// In this case the iteration is over grad_in and performs the following copy:
// grad_out[..., i_out_dim, ...] = grad_in[..., i_in_dim, ..., i_in_last_dim],
// where i_out_dim = i_in_dim * step + i_in_last_dim
// (the folds are disjoint, so each element of grad_out is written by at most
// one element of grad_in).
//
// case2. step < size.
// In this case the iteration is over grad_out,
// where grad_out[..., i_out_dim, ...] accumulates all the values
// grad_in[..., i_in_dim, ..., i_in_last_dim] with
// i_in_dim in [left_idx_fold, right_idx_fold],
// i_in_last_dim = i_out_dim - i_in_dim * step,
// left_idx_fold = (i_out_dim - size) / step
//  if i_out_dim in [left_idx_fold * step, left_idx_fold * step + size),
//  else (i_out_dim - size) / step + 1,
// right_idx_fold = i_out_dim / step.
//
// Simply put: given i_out_dim, we find which folds of grad_in
// intersect with i_out_dim (these are precisely [left_idx_fold, right_idx_fold]),
// and the corresponding values grad_in[..., i_in_dim, ..., i_in_last_dim]
// get added up into grad_out[..., i_out_dim, ...].
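//
// A small worked example (illustrative numbers only, not taken from any
// particular call site): let in.shape[dim] == 6, size == 3, step == 2
// (case2, since step < size). Then
// grad_in.shape[dim] == (6 - 3) / 2 + 1 == 2 and grad_in.shape[-1] == 3,
// fold 0 covers in-indices [0, 3) and fold 1 covers [2, 5), so
// grad_out[..., 2, ...] = grad_in[..., 0, ..., 2] + grad_in[..., 1, ..., 0],
// in-indices 0, 1, 3, 4 each receive exactly one term, and in-index 5
// (covered by no fold) stays zero.
// With size == 2, step == 3 instead (case1, since step >= size) the folds
// [0, 2) and [3, 5) are disjoint, so the kernel reduces to a plain copy.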

namespace at::native {

namespace {

template <typename scalar_t>
void _unfold_backward_internal_kernel(
  TensorIterator& iter,
  int64_t size,
  int64_t step,
  int64_t grad_in_dim_stride,
  int64_t grad_in_last_dim_stride,
  int64_t grad_in_dim_size,
  int64_t grad_out_dim_stride
) {
  if (iter.numel() == 0) {
    return;
  }

  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* RESTRICT grad_out_ptr = data[0];
    auto* RESTRICT grad_in_ptr = data[1];
    auto* RESTRICT idx_dim_ptr = data[2];

    for (const auto elem C10_UNUSED : c10::irange(nelems)) {
      auto* RESTRICT grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr);
      auto* RESTRICT grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr);

      auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr);

      // left_fold potentially intersecting with idx_dim
      // is either (idx_dim - size) / step or the next integer.
      int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
      if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
        ++left_fold_idx;
      }

      auto right_fold_idx = idx_dim / step;
      right_fold_idx = (right_fold_idx >= grad_in_dim_size)
        ? (grad_in_dim_size - 1) : right_fold_idx;
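      // Example (illustrative values: size == 3, step == 2,
      // grad_in_dim_size == 3, idx_dim == 4): (idx_dim - size) / step == 0,
      // but fold 0 only covers in-indices [0, 3), so left_fold_idx is bumped
      // to 1; right_fold_idx == 4 / 2 == 2. Folds 1 and 2 cover [2, 5) and
      // [4, 7) respectively, and both contain idx_dim == 4.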

      for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
        auto idx_last_dim = idx_dim - fold_idx * step;
        *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
                                    + idx_last_dim * grad_in_last_dim_stride];
      }

      grad_out_ptr += strides[0];
      grad_in_ptr += strides[1];
      idx_dim_ptr += strides[2];
    }
  };

  iter.for_each(loop);
}

void unfold_backward_cpu_kernel(
  Tensor& grad_out,
  const Tensor& grad_in,
  int64_t dim,
  int64_t size,
  int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // last dim stores the folds
  auto last_dim = maybe_wrap_dim(-1, grad_in.dim());

  auto grad_in_dim_stride = ensure_nonempty_stride(grad_in, dim);
  auto grad_in_last_dim_stride = ensure_nonempty_stride(grad_in, last_dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);

  auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim);

  TensorIterator iter = _make_unfold_backward_iter_over_grad_out(
      grad_out, grad_in, dim, size, step);
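  // The iterator produces three operands per element, consumed by the loop in
  // _unfold_backward_internal_kernel above as data[0] (the grad_out element
  // being accumulated into), data[1] (a base pointer into grad_in) and
  // data[2] (the int64_t index of that element along `dim`).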

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
    iter.dtype(),
    "unfold_backward_cpu", [&] {
      _unfold_backward_internal_kernel<scalar_t>(
        iter,
        size,
        step,
        grad_in_dim_stride,
        grad_in_last_dim_stride,
        grad_in_dim_size,
        grad_out_dim_stride
      );
    }
  );
}

} // anonymous namespace

REGISTER_DISPATCH(unfold_backward_stub, &unfold_backward_cpu_kernel);
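
// Rough usage sketch (assumes the libtorch C++ frontend; in practice this
// kernel is reached through autograd via unfold_backward_stub rather than
// being called directly):
//
//   #include <torch/torch.h>
//
//   auto x = torch::randn({7}, torch::requires_grad());
//   auto y = x.unfold(/*dimension=*/0, /*size=*/3, /*step=*/2);  // shape {3, 3}
//   y.sum().backward();  // x.grad() is filled in by unfold_backward,
//                        // dispatched on CPU to unfold_backward_cpu_kernel.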

} // namespace at::native