#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/TensorTransformations.h>
#include <ATen/native/IndexKernel.h> // for flip_stub

#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/atleast_1d_native.h>
#include <ATen/ops/atleast_2d_native.h>
#include <ATen/ops/atleast_3d_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/chalf_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/flip_native.h>
#include <ATen/ops/fliplr_native.h>
#include <ATen/ops/flipud_native.h>
#include <ATen/ops/roll_native.h>
#include <ATen/ops/rot90_native.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>
#include <utility>
#include <vector>

namespace at::native {

Tensor flip(const Tensor& self, IntArrayRef dims) {
  const int64_t total_dims = self.dim();
  // dim_list_to_bitset wraps the dims and checks that there are no repeated dims
  auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);

  Tensor out_tensor = at::empty_like(self, MemoryFormat::Preserve);

  // Count dimensions in which we need to do work
  int n = 0;
  auto strides = DimVector(self.strides());
  for (const auto i : c10::irange(total_dims)) {
    if (flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) {
      n++;
      strides[i] = 0;
    }
  }

  // Nothing to do, we return fast
  if (n == 0 || self.numel() <= 1) {
    out_tensor.copy_(self);
    return out_tensor;
  }

  // Create a dummy input (self restrided with 0 strides at the flipped dims) to prevent
  // TensorIterator from coalescing the flipped dims with their neighbours
  const auto restrided_self = self.as_strided(self.sizes(), strides);
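  // Illustrative sketch (assuming a contiguous 2x3 tensor flipped along dim 0): self has
  // strides (3, 1), so restrided_self gets strides (0, 1). The zero stride keeps the
  // iterator from merging dim 0 and dim 1 into one flat loop, so the flipped dimension
  // stays visible and its output stride can be negated below.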
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .declare_static_dtype_and_device(self.scalar_type(), self.device())
    .add_output(out_tensor)
    .add_const_input(self)
    .add_const_input(restrided_self)
    .build();

  auto* data = reinterpret_cast<char*>(iter.data_ptr(0));
  const auto sizes = iter.shape();
  // This is a SmallVector of _signed_ ints
  auto strides_bytes = DimVector(iter.strides(0));
  const auto strides_self = iter.strides(1);
  const auto strides_dummy = iter.strides(2);

  // To understand this transformation, think of a 3D cube.
  // - The data ptr points to the lower-left most vertex of the cube
  // - The strides tell us how to move in each dimension,
  //   that is, data + stride[i] advances one element in the dimension i
  // To flip a dimension:
  // - We move the pointer to the opposite vertex of the cube
  // - We iterate in the opposite direction (invert the strides)

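  // Worked sketch (illustrative): for a 1-D contiguous float tensor of 8 elements flipped
  // along dim 0, the output pointer is advanced by 4 bytes * (8 - 1) = 28 bytes to the last
  // element and its stride becomes -4, so the value read from self[i] lands in out[7 - i].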
  for (const auto i : c10::irange(iter.ndim())) {
    // We know that this dimension has a zero stride and self[i] does not, as we defined above.
    // Note that it may be the case that strides_dummy[i] == 0 not because we set it, but because
    // strides_self[i] == 0. We do not want to do anything in that case
    if (strides_dummy[i] == 0 && strides_self[i] != 0) {
      data += strides_bytes[i] * (sizes[i] - 1);
      strides_bytes[i] *= -1;
    }
  }
  iter._unsafe_set_arg_strides(0, strides_bytes);
  iter._unsafe_set_arg_data(0, reinterpret_cast<void*>(data));

  flip_stub(iter.device_type(), iter, self.is_quantized());

  return out_tensor;
}

Tensor roll(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { // Used by CPU and MPS dispatch.
  if (dims.size() != 1 || shifts.size() != 1) {
    return roll_common(self, shifts, dims);
  }
  // avoid a div zero error below.
  if (self.numel() == 0) {
    return self.clone(at::MemoryFormat::Preserve);
  }
  int64_t dim = dims[0];
  int64_t size = self.size(dim);
  int64_t start = (size - shifts[0]) % size;
  // Behavior of % is different in C++ vs Python for negative numbers. This
  // corrects the difference.
  if (start < 0) {
    start = start + size;
  }
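  // Worked example (illustrative): rolling a 1-D tensor [a, b, c, d, e] by shift 2 along
  // dim 0 gives start = (5 - 2) % 5 = 3, so t0 = [d, e], t1 = [a, b, c], and the
  // concatenation below yields [d, e, a, b, c].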
  auto t0 = self.narrow(dim, start, size - start);
  auto t1 = self.narrow(dim, 0, start);
  return at::cat({std::move(t0), std::move(t1)}, dim);
}

Tensor rot90(const Tensor& self, int64_t k, IntArrayRef dims) {
  const int64_t total_dims = self.dim(), total_rot_dims = dims.size();

  TORCH_CHECK(total_rot_dims == 2,
    "expected total rotation dims == 2, but got dims = ", total_rot_dims);

  TORCH_CHECK(total_dims >= 2,
    "expected total dims >= 2, but got total dims = ", total_dims);

  TORCH_CHECK(dims[0] != dims[1] && std::abs(dims[0] - dims[1]) != total_dims,
    "expected rotation dims to be different, but got dim0 = ", dims[0],
    " and dim1 = ", dims[1]);

  // check range of dims
  TORCH_CHECK(dims[0] < total_dims && dims[0] >= -total_dims,
    "Rotation dim0 out of range, dim0 = ", dims[0]);

  TORCH_CHECK(dims[1] < total_dims && dims[1] >= -total_dims,
    "Rotation dim1 out of range, dim1 = ", dims[1]);

  // handle modulo with negative k
  k = (4 + (k % 4)) % 4;
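  // For instance, k = -1 gives (-1 % 4) == -1 in C++, so k becomes (4 - 1) % 4 == 3 and a
  // -90 degree rotation is carried out as three 90 degree rotations. Illustrative 2x2 case
  // for k == 1 with dims = {0, 1}: [[1, 2], [3, 4]] -> flip dim 1 -> [[2, 1], [4, 3]]
  // -> transpose -> [[2, 4], [1, 3]].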

  switch (k) {
    case 1:
      return self.flip({dims[1]}).transpose_(dims[0], dims[1]);
    case 2:
      return self.flip(dims);
    case 3:
      return self.flip({dims[0]}).transpose_(dims[0], dims[1]);
    default:
      return self.clone(at::MemoryFormat::Contiguous);
  }
}

Tensor fliplr(const Tensor& self) {
  TORCH_CHECK(self.dim() >= 2, "Input must be >= 2-d.");

  return self.flip({1});
}

Tensor flipud(const Tensor& self) {
  TORCH_CHECK(self.dim() >= 1, "Input must be >= 1-d.");

  return self.flip({0});
}

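// Shape promotion performed by atleast_{1,2,3}d (summarized from the cases below; inputs
// that already have enough dimensions are returned as-is):
//   atleast_1d: ()  ->  (1)
//   atleast_2d: ()  ->  (1, 1),     (N)  ->  (1, N)
//   atleast_3d: ()  ->  (1, 1, 1),  (N)  ->  (1, N, 1),  (N, M)  ->  (N, M, 1)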
Tensor atleast_1d(const Tensor& self) {
  switch (self.dim()) {
    case 0:
      return self.reshape({1});
    default:
      return self;
  }
}

std::vector<Tensor> atleast_1d(TensorList tensors) {
  std::vector<Tensor> result(tensors.size());
  auto transform_lambda = [](const Tensor& input) -> Tensor {
    return at::native::atleast_1d(input);
  };
  std::transform(tensors.cbegin(), tensors.cend(), result.begin(), transform_lambda);
  return result;
}

Tensor atleast_2d(const Tensor& self) {
  switch (self.dim()) {
    case 0:
      return self.reshape({1, 1});
    case 1: {
      return self.unsqueeze(0);
    }
    default:
      return self;
  }
}

std::vector<Tensor> atleast_2d(TensorList tensors) {
  std::vector<Tensor> result(tensors.size());
  auto transform_lambda = [](const Tensor& input) -> Tensor {
    return at::native::atleast_2d(input);
  };
  std::transform(tensors.cbegin(), tensors.cend(), result.begin(), transform_lambda);
  return result;
}

Tensor atleast_3d(const Tensor& self) {
  switch (self.dim()) {
    case 0:
      return self.reshape({1, 1, 1});
    case 1: {
      return self.unsqueeze(0).unsqueeze(-1);
    }
    case 2: {
      return self.unsqueeze(-1);
    }
    default:
      return self;
  }
}

std::vector<Tensor> atleast_3d(TensorList tensors) {
  std::vector<Tensor> result(tensors.size());
  auto transform_lambda = [](const Tensor& input) -> Tensor {
    return at::native::atleast_3d(input);
  };
  std::transform(tensors.cbegin(), tensors.cend(), result.begin(), transform_lambda);
  return result;
}

Tensor chalf(const Tensor& self, std::optional<MemoryFormat> memory_format) {
  return self.to(kComplexHalf, false, false, memory_format);
}

DEFINE_DISPATCH(flip_stub);

} // namespace at::native