#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/NonEmptyUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#endif

namespace at::native {

using unfold_backward_fn = void (*)(
  Tensor& grad_in,
  const Tensor& grad,
  int64_t dim,
  int64_t size,
  int64_t step
);

DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub);
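
// The stub declared above is filled in per backend via REGISTER_DISPATCH and,
// roughly speaking, the out-of-line unfold_backward implementation invokes it
// like the sketch below (illustrative only; the variable names are assumptions,
// not the exact call site):
//
//   unfold_backward_stub(
//       grad.device().type(), grad_input, grad, dim, size, step);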

namespace {

// Note on naming: it is unconventional.
// grad_in does not mean the gradient w.r.t. the input;
// grad_in/grad_out are simply the input/output of the unfold_backward kernel.

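// Shape background (a summary, not a constraint enforced here): the forward op
// `t.unfold(dim, size, step)` returns a view whose size along `dim` is
// (t.size(dim) - size) / step + 1 and which gains a trailing dimension of
// length `size` holding each fold. In the helper below, `grad_out` is the
// gradient buffer shaped like the forward input, while `grad_in` is the
// incoming gradient shaped like the unfold output (folds in the last dim).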
static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out(
  Tensor& grad_out,
  const Tensor& grad_in,
  int64_t dim,
  int64_t size,
  int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // last dim stores the folds

  auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
  // dictates the number of elements to iterate over
  // in dimension `dim`
  auto iter_dim_size = std::min(
    grad_out_dim_size,
    (grad_in_dim_size - 1) * step + size
  );
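  // Illustrative numbers (not taken from the source): with
  // grad_out.size(dim) == 10, size == 3, step == 2, the forward unfold produced
  // grad_in.size(dim) == (10 - 3) / 2 + 1 == 4 folds, which only cover indices
  // [0, (4 - 1) * 2 + 3) == [0, 9) along dim; iter_dim_size == 9, so the
  // uncovered tail element of grad_out is left out of the iteration.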

  /* prepare grad_out for TensorIterator { */
  auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
  auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());
  grad_out_sizes[dim] = iter_dim_size;
  auto grad_out_restrided = grad_out.as_strided(
    grad_out_sizes, grad_out_strides
  );
  /* } */

  /* prepare grad_in for TensorIterator { */
  auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
  auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());

  // set the stride for dim to 0 and its size to 1
  // because this dimension is indexed inside the kernel
  grad_in_strides[dim] = 0;
  grad_in_sizes[dim] = 1;

  grad_in_strides.pop_back();
  grad_in_sizes.pop_back();

  auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
    grad_in_sizes, grad_in_strides
  );
  /* } */
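
  // An explanatory aside (describing the restride above, not adding behavior):
  // with stride 0 and size 1 along `dim`, the same grad_in slice is visible at
  // every output position along `dim`, and popping the last entry drops the
  // fold dimension from the iterator's view; the kernel is expected to index
  // both the fold and the offset within the fold by itself.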

  // During the TensorIterator iteration we have to know
  // i_dim in grad_out[i_1,...,i_dim,...,i_n];
  // idx_dim stores this information.
  /* prepare idx_dim for TensorIterator { */
  auto idx_dim = at::arange(
    0, iter_dim_size, grad_in.options().dtype(at::kLong)
  );

  auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());

  auto idx_dim_strides = std::vector<int64_t>(grad_out_dim, 0);
  auto idx_dim_sizes = std::vector<int64_t>(grad_out_dim, 1);

  idx_dim_strides[dim] = 1;
  idx_dim_sizes[dim] = iter_dim_size;

  // idx_dim will be broadcast over the sizes determined by grad_out in TensorIterator
  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  /* } */
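
  // For example (hypothetical shapes): with grad_out of shape [4, 10, 5],
  // dim == 1 and iter_dim_size == 9, idx_dim_restrided views arange(0, 9) with
  // sizes [1, 9, 1] and strides [0, 1, 0], so every element visited by the
  // iterator also carries its index along `dim`.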

  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_owned_output(grad_out_restrided)
    .add_owned_const_input(grad_in_restrided)
    .add_owned_const_input(idx_dim_restrided)
    .build();

  return iter;
}
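
// How the iterator is typically consumed (a summary of the math implied by the
// restriding above, not a spec of any particular backend): for a grad_out
// element at index i along `dim`, a kernel reads i from idx_dim and accumulates
//   sum_{j : 0 <= i - j*step < size} grad_in[..., j, ..., i - j*step],
// i.e. the contribution of every fold that covered position i in the forward
// unfold.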

} // anonymous namespace

} // namespace at::native