1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorOperators.h>
4
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/affine_grid_generator_backward_native.h>
10 #include <ATen/ops/affine_grid_generator_native.h>
11 #include <ATen/ops/empty.h>
12 #include <ATen/ops/linspace.h>
13 #include <ATen/ops/tensor.h>
14 #endif
15
16 namespace at::native {
17
linspace_from_neg_one(const Tensor & grid,int64_t num_steps,bool align_corners)18 static at::Tensor linspace_from_neg_one(const Tensor& grid, int64_t num_steps,
19 bool align_corners) {
20 if (num_steps <= 1) {
21 return at::tensor(0, grid.options());
22 }
23 auto range = at::linspace(-1, 1, num_steps, grid.options());
24 if (!align_corners) {
25 range = range * (num_steps - 1) / num_steps;
26 }
27 return range;
28 }
29
// Builds the homogeneous base sampling grid for a 4-D (N, C, H, W) input.
// Result has shape (N, H, W, 3); the last dimension holds (x, y, 1), with
// x spanning W and y spanning H in normalized [-1, 1] coordinates.
// `theta` is used only for its tensor options (dtype/device); `C` is accepted
// for signature parity with the caller but is not needed here.
static Tensor make_base_grid_4D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W,
    bool align_corners) {
  auto grid = at::empty({N, H, W, 3}, theta.options());

  // Homogeneous coordinate first, then y (broadcast down columns), then x.
  grid.select(-1, 2).fill_(1);
  grid.select(-1, 1).copy_(
      linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1));
  grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners));

  return grid;
}
45
// Builds the homogeneous base sampling grid for a 5-D (N, C, D, H, W) input.
// Result has shape (N, D, H, W, 4); the last dimension holds (x, y, z, 1),
// each axis in normalized [-1, 1] coordinates and broadcast across the others
// via trailing unsqueezes. `theta` supplies only dtype/device; `C` is accepted
// for signature parity with the caller but is not needed here.
static Tensor make_base_grid_5D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    bool align_corners) {
  auto grid = at::empty({N, D, H, W, 4}, theta.options());

  // Homogeneous coordinate first, then z, y, x with the appropriate
  // trailing singleton dims so each broadcasts over the remaining axes.
  grid.select(-1, 3).fill_(1);
  grid.select(-1, 2).copy_(
      linspace_from_neg_one(theta, D, align_corners).unsqueeze_(-1).unsqueeze_(-1));
  grid.select(-1, 1).copy_(
      linspace_from_neg_one(theta, H, align_corners).unsqueeze_(-1));
  grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W, align_corners));

  return grid;
}
63
// Applies the batch of 2x3 affine matrices `theta` to the homogeneous 4-D
// base grid via a single batched matmul, producing sampling coordinates of
// shape (N, H, W, 2) for grid_sample.
static Tensor affine_grid_generator_4D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W,
    bool align_corners) {
  auto flat_grid = make_base_grid_4D(theta, N, C, H, W, align_corners)
                       .view({N, H * W, 3});
  // (N, H*W, 3) x (N, 3, 2) -> (N, H*W, 2)
  return flat_grid.bmm(theta.transpose(1, 2)).view({N, H, W, 2});
}
75
// Applies the batch of 3x4 affine matrices `theta` to the homogeneous 5-D
// base grid via a single batched matmul, producing sampling coordinates of
// shape (N, D, H, W, 3) for grid_sample.
static Tensor affine_grid_generator_5D(
    const Tensor& theta,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    bool align_corners) {
  auto flat_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners)
                       .view({N, D * H * W, 4});
  // (N, D*H*W, 4) x (N, 4, 3) -> (N, D*H*W, 3)
  return flat_grid.bmm(theta.transpose(1, 2)).view({N, D, H, W, 3});
}
88
// Public entry point: dispatches on the target `size` rank to the spatial
// (4-D) or volumetric (5-D) grid generator. `size` is the output shape
// (N, C, H, W) or (N, C, D, H, W); only rank 4 and 5 are accepted.
Tensor affine_grid_generator(const Tensor& theta, IntArrayRef size, bool align_corners) {
  const auto rank = size.size();
  TORCH_CHECK(
      rank == 4 || rank == 5,
      "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.");
  if (rank == 5) {
    return affine_grid_generator_5D(
        theta, size[0], size[1], size[2], size[3], size[4], align_corners);
  }
  return affine_grid_generator_4D(
      theta, size[0], size[1], size[2], size[3], align_corners);
}
101
// Backward of the 4-D generator: since grid = base_grid @ theta^T, the
// gradient w.r.t. theta is base_grid^T @ grad_grid, transposed back into
// theta's (N, 2, 3) layout. `grad_grid` must be (N, H, W, 2).
static Tensor affine_grid_generator_4D_backward(
    const Tensor& grad_grid,
    int64_t N,
    int64_t C,
    int64_t H,
    int64_t W,
    bool align_corners) {
  AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, H, W, 2}));
  auto base_grid = make_base_grid_4D(grad_grid, N, C, H, W, align_corners);
  // (N, 3, H*W) x (N, H*W, 2) -> (N, 3, 2), then transpose to (N, 2, 3).
  auto grad_theta = base_grid.view({N, H * W, 3})
                        .transpose(1, 2)
                        .bmm(grad_grid.reshape({N, H * W, 2}));
  return grad_theta.transpose(1, 2);
}
116
// Backward of the 5-D generator: gradient w.r.t. theta is
// base_grid^T @ grad_grid, transposed back into theta's (N, 3, 4) layout.
// `grad_grid` must be (N, D, H, W, 3).
static Tensor affine_grid_generator_5D_backward(
    const Tensor& grad_grid,
    int64_t N,
    int64_t C,
    int64_t D,
    int64_t H,
    int64_t W,
    bool align_corners) {
  AT_ASSERT(grad_grid.sizes() == IntArrayRef({N, D, H, W, 3}));
  auto base_grid = make_base_grid_5D(grad_grid, N, C, D, H, W, align_corners);
  // (N, 4, D*H*W) x (N, D*H*W, 3) -> (N, 4, 3), then transpose to (N, 3, 4).
  auto grad_theta = base_grid.view({N, D * H * W, 4})
                        .transpose(1, 2)
                        .bmm(grad_grid.reshape({N, D * H * W, 3}));
  return grad_theta.transpose(1, 2);
}
132
// Public backward entry point: dispatches on the rank of `size` to the
// spatial (4-D) or volumetric (5-D) backward kernel. `size` is the forward
// output shape; only rank 4 and 5 are accepted.
Tensor affine_grid_generator_backward(const Tensor& grad, IntArrayRef size, bool align_corners) {
  const auto rank = size.size();
  TORCH_CHECK(
      rank == 4 || rank == 5,
      "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.");
  if (rank == 5) {
    return affine_grid_generator_5D_backward(
        grad, size[0], size[1], size[2], size[3], size[4], align_corners);
  }
  return affine_grid_generator_4D_backward(
      grad, size[0], size[1], size[2], size[3], align_corners);
}
145
146 } // namespace at::native
147