xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/AffineGridGenerator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorOperators.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/affine_grid_generator_backward_native.h>
10 #include <ATen/ops/affine_grid_generator_native.h>
11 #include <ATen/ops/empty.h>
12 #include <ATen/ops/linspace.h>
13 #include <ATen/ops/tensor.h>
14 #endif
15 
16 namespace at::native {
17 
linspace_from_neg_one(const Tensor & grid,int64_t num_steps,bool align_corners)18 static at::Tensor linspace_from_neg_one(const Tensor& grid, int64_t num_steps,
19                                  bool align_corners) {
20   if (num_steps <= 1) {
21     return at::tensor(0, grid.options());
22   }
23   auto range = at::linspace(-1, 1, num_steps, grid.options());
24   if (!align_corners) {
25     range = range * (num_steps - 1) / num_steps;
26   }
27   return range;
28 }
29 
// Builds the homogeneous base grid for a 2-D (spatial) output.
//
// Returns a tensor of shape (N, H, W, 3) with theta's options. The last
// dimension holds (x, y, 1):
//   - x varies along W;
//   - y varies along H;
//   - the third channel is the constant 1, so a matrix product with theta
//     adds theta's last column as an offset.
// C is part of the signature but is not used.
static Tensor make_base_grid_4D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W,
    bool align_corners) {
  Tensor base = at::empty({N, H, W, 3}, theta.options());

  // 1-D ramp over W (0-dim zero when W <= 1); broadcasts across N and H.
  const Tensor x_coords = linspace_from_neg_one(theta, W, align_corners);
  // Ramp over H reshaped to a trailing singleton so it broadcasts across W.
  const Tensor y_coords =
      linspace_from_neg_one(theta, H, align_corners).unsqueeze(-1);

  base.select(-1, 0).copy_(x_coords);
  base.select(-1, 1).copy_(y_coords);
  base.select(-1, 2).fill_(1);

  return base;
}
45 
// Builds the homogeneous base grid for a 3-D (volumetric) output.
//
// Returns a tensor of shape (N, D, H, W, 4) with theta's options. The last
// dimension holds (x, y, z, 1):
//   - x varies along W;
//   - y varies along H;
//   - z varies along D;
//   - the fourth channel is the constant 1, so a matrix product with theta
//     adds theta's last column as an offset.
// C is part of the signature but is not used.
static Tensor make_base_grid_5D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    bool align_corners) {
  Tensor base = at::empty({N, D, H, W, 4}, theta.options());

  // 1-D ramp over W (0-dim zero when W <= 1); broadcasts across N, D, H.
  const Tensor x_coords = linspace_from_neg_one(theta, W, align_corners);
  // Ramp over H with one trailing singleton: broadcasts across W.
  const Tensor y_coords =
      linspace_from_neg_one(theta, H, align_corners).unsqueeze(-1);
  // Ramp over D with two trailing singletons: broadcasts across H and W.
  const Tensor z_coords =
      linspace_from_neg_one(theta, D, align_corners).unsqueeze(-1).unsqueeze(-1);

  base.select(-1, 0).copy_(x_coords);
  base.select(-1, 1).copy_(y_coords);
  base.select(-1, 2).copy_(z_coords);
  base.select(-1, 3).fill_(1);

  return base;
}
63 
// Forward pass for a 2-D (spatial) affine grid.
//
// theta must have shape (N, 2, 3). Every homogeneous (x, y, 1) point of the
// (N, H, W, 3) base grid is mapped through its batch's matrix:
//   grid[n] = base[n] @ theta[n]^T
// The result is reshaped to (N, H, W, 2). C is unused.
static Tensor affine_grid_generator_4D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W,
    bool align_corners) {
  // Validate before allocating. Otherwise a mis-shaped theta either fails
  // deep inside bmm/view with an opaque message, or (when N*H*W == 0) is
  // silently accepted by the final view.
  TORCH_CHECK(
      theta.dim() == 3 && theta.size(0) == N && theta.size(1) == 2 &&
          theta.size(2) == 3,
      "affine_grid_generator: expected theta of shape [", N,
      ", 2, 3] for a 4D output size, but got ", theta.sizes());

  Tensor base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners);
  // (N, H*W, 3) x (N, 3, 2) -> (N, H*W, 2)
  auto grid = base_grid.view({N, H * W, 3}).bmm(theta.transpose(1, 2));
  return grid.view({N, H, W, 2});
}
75 
// Forward pass for a 3-D (volumetric) affine grid.
//
// theta must have shape (N, 3, 4). Every homogeneous (x, y, z, 1) point of
// the (N, D, H, W, 4) base grid is mapped through its batch's matrix:
//   grid[n] = base[n] @ theta[n]^T
// The result is reshaped to (N, D, H, W, 3). C is unused.
static Tensor affine_grid_generator_5D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    bool align_corners) {
  // Validate before allocating. Otherwise a mis-shaped theta either fails
  // deep inside bmm/view with an opaque message, or (when N*D*H*W == 0) is
  // silently accepted by the final view.
  TORCH_CHECK(
      theta.dim() == 3 && theta.size(0) == N && theta.size(1) == 3 &&
          theta.size(2) == 4,
      "affine_grid_generator: expected theta of shape [", N,
      ", 3, 4] for a 5D output size, but got ", theta.sizes());

  Tensor base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners);
  // (N, D*H*W, 4) x (N, 4, 3) -> (N, D*H*W, 3)
  auto grid = base_grid.view({N, D * H * W, 4}).bmm(theta.transpose(1, 2));
  return grid.view({N, D, H, W, 3});
}
88 
// Public entry point: generates an affine sampling grid.
//
// `size` is the target output size. With 4 entries it is read as
// (N, C, H, W) and dispatched to the spatial path; with 5 entries it is
// read as (N, C, D, H, W) and dispatched to the volumetric path. Any other
// length is rejected via TORCH_CHECK.
Tensor affine_grid_generator(const Tensor& theta, IntArrayRef size, bool align_corners) {
  const auto rank = size.size();
  TORCH_CHECK(
      rank == 4 || rank == 5,
      "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.");

  if (rank == 5) {
    return affine_grid_generator_5D(
        theta, size[0], size[1], size[2], size[3], size[4], align_corners);
  }
  return affine_grid_generator_4D(
      theta, size[0], size[1], size[2], size[3], align_corners);
}
101 
// Backward pass for the 2-D (spatial) affine grid.
//
// grad_grid must have shape (N, H, W, 2). The forward computed
// grid[n] = base[n] @ theta[n]^T, so the gradient w.r.t. theta is
//   grad_theta[n] = (base[n]^T @ grad[n])^T
// with shape (N, 2, 3). The result is returned as a transposed view. C is unused.
static Tensor affine_grid_generator_4D_backward(
    const Tensor& grad_grid,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W,
    bool align_corners) {
  // The op is callable directly, so a wrong grad shape is a user error, not
  // an internal invariant. Use TORCH_CHECK (AT_ASSERT is deprecated and
  // reports an internal bug), and check before allocating the base grid.
  TORCH_CHECK(
      grad_grid.sizes() == IntArrayRef({N, H, W, 2}),
      "affine_grid_generator_backward: expected grad of shape [", N, ", ", H,
      ", ", W, ", 2], but got ", grad_grid.sizes());

  auto base_grid = make_base_grid_4D(grad_grid, N, C, H, W, align_corners);
  // (N, 3, H*W) x (N, H*W, 2) -> (N, 3, 2)
  auto grad_theta = base_grid.view({N, H * W, 3})
                        .transpose(1, 2)
                        .bmm(grad_grid.reshape({N, H * W, 2}));
  return grad_theta.transpose(1, 2);
}
116 
// Backward pass for the 3-D (volumetric) affine grid.
//
// grad_grid must have shape (N, D, H, W, 3). The forward computed
// grid[n] = base[n] @ theta[n]^T, so the gradient w.r.t. theta is
//   grad_theta[n] = (base[n]^T @ grad[n])^T
// with shape (N, 3, 4). The result is returned as a transposed view. C is unused.
static Tensor affine_grid_generator_5D_backward(
    const Tensor& grad_grid,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    bool align_corners) {
  // The op is callable directly, so a wrong grad shape is a user error, not
  // an internal invariant. Use TORCH_CHECK (AT_ASSERT is deprecated and
  // reports an internal bug), and check before allocating the base grid.
  TORCH_CHECK(
      grad_grid.sizes() == IntArrayRef({N, D, H, W, 3}),
      "affine_grid_generator_backward: expected grad of shape [", N, ", ", D,
      ", ", H, ", ", W, ", 3], but got ", grad_grid.sizes());

  auto base_grid = make_base_grid_5D(grad_grid, N, C, D, H, W, align_corners);
  // (N, 4, D*H*W) x (N, D*H*W, 3) -> (N, 4, 3)
  auto grad_theta = base_grid.view({N, D * H * W, 4})
                        .transpose(1, 2)
                        .bmm(grad_grid.reshape({N, D * H * W, 3}));
  return grad_theta.transpose(1, 2);
}
132 
// Public entry point for the backward pass of affine_grid_generator.
//
// `size` is the forward output size. With 4 entries it is read as
// (N, C, H, W) and dispatched to the spatial backward; with 5 entries it is
// read as (N, C, D, H, W) and dispatched to the volumetric backward. Any
// other length is rejected via TORCH_CHECK. `grad` is the incoming gradient
// w.r.t. the generated grid.
Tensor affine_grid_generator_backward(const Tensor& grad, IntArrayRef size, bool align_corners) {
  const auto rank = size.size();
  TORCH_CHECK(
      rank == 4 || rank == 5,
      "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.");

  if (rank == 5) {
    return affine_grid_generator_5D_backward(
        grad, size[0], size[1], size[2], size[3], size[4], align_corners);
  }
  return affine_grid_generator_4D_backward(
      grad, size[0], size[1], size[2], size[3], align_corners);
}
145 
146 }  // namespace at::native
147