#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/TensorTransformations.h>
#include <ATen/native/IndexKernel.h>  // for flip_stub

#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/atleast_1d_native.h>
#include <ATen/ops/atleast_2d_native.h>
#include <ATen/ops/atleast_3d_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/chalf_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/flip_native.h>
#include <ATen/ops/fliplr_native.h>
#include <ATen/ops/flipud_native.h>
#include <ATen/ops/roll_native.h>
#include <ATen/ops/rot90_native.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>
#include <utility>
#include <vector>

namespace at::native {
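
// For example (illustrative): flipping dim 0 of [[1, 2], [3, 4]] yields
// [[3, 4], [1, 2]]; flipping dims {0, 1} yields [[4, 3], [2, 1]].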

Tensor flip(const Tensor& self, IntArrayRef dims) {
  const int64_t total_dims = self.dim();
  // It wraps the dims and checks that there are no repeated dims
  auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims);

  Tensor out_tensor = at::empty_like(self, MemoryFormat::Preserve);

  // Count dimensions in which we need to do work
  int n = 0;
  auto strides = DimVector(self.strides());
  for (const auto i : c10::irange(total_dims)) {
    if (flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) {
      n++;
      strides[i] = 0;
    }
  }

  // Nothing to do, we return fast
  if (n == 0 || self.numel() <= 1) {
    out_tensor.copy_(self);
    return out_tensor;
  }

  // Create a dummy input with zero strides at the flipped dimensions, to
  // prevent TensorIterator from coalescing the flipped dims
  const auto restrided_self = self.as_strided(self.sizes(), strides);
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .declare_static_dtype_and_device(self.scalar_type(), self.device())
    .add_output(out_tensor)
    .add_const_input(self)
    .add_const_input(restrided_self)
    .build();

  auto* data = reinterpret_cast<char*>(iter.data_ptr(0));
  const auto sizes = iter.shape();
  // This is a SmallVector of _signed_ ints
  auto strides_bytes = DimVector(iter.strides(0));
  const auto strides_self = iter.strides(1);
  const auto strides_dummy = iter.strides(2);

  // To understand this transformation, think of a 3D cube.
  //   - The data ptr points to the lower-left most vertex of the cube
  //   - The strides tell us how to move in each dimension;
  //     that is, data + stride[i] advances one element in dimension i
  // To flip a dimension:
  //   - We move the pointer to the opposite vertex of the cube
  //   - We iterate in the opposite direction (invert the strides)
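
  // For example (illustrative): flipping a contiguous 1-D float tensor of
  // size 4, the output stride is 4 bytes, so we advance the output pointer
  // by 3 * 4 bytes to the last element and negate the stride to -4 bytes;
  // the kernel then writes out[3 - i] while reading self[i].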

  for (const auto i : c10::irange(iter.ndim())) {
    // A dimension is flipped iff the dummy has a zero stride there while
    // self does not, as we set up above. Note that strides_dummy[i] may be 0
    // not because we set it, but because strides_self[i] == 0; we do not
    // want to do anything in that case.
    if (strides_dummy[i] == 0 && strides_self[i] != 0) {
      data += strides_bytes[i] * (sizes[i] - 1);
      strides_bytes[i] *= -1;
    }
  }
  iter._unsafe_set_arg_strides(0, strides_bytes);
  iter._unsafe_set_arg_data(0, reinterpret_cast<void*>(data));

  flip_stub(iter.device_type(), iter, self.is_quantized());

  return out_tensor;
}

// Used by CPU and MPS dispatch.
Tensor roll(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
  if (dims.size() != 1 || shifts.size() != 1) {
    return roll_common(self, shifts, dims);
  }
  // Avoid a division-by-zero error below.
  if (self.numel() == 0) {
    return self.clone(at::MemoryFormat::Preserve);
  }
  int64_t dim = dims[0];
  int64_t size = self.size(dim);
  int64_t start = (size - shifts[0]) % size;
  // Behavior of % is different in C++ vs Python for negative numbers. This
  // corrects the difference.
  if (start < 0) {
    start = start + size;
  }
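
  // For example (illustrative): rolling [a, b, c, d, e] by shifts = {2}
  // along dim 0 gives start = (5 - 2) % 5 = 3, so the two narrowed pieces
  // below are [d, e] and [a, b, c], and the result is [d, e, a, b, c].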
  auto t0 = self.narrow(dim, start, size - start);
  auto t1 = self.narrow(dim, 0, start);
  return at::cat({std::move(t0), std::move(t1)}, dim);
}

Tensor rot90(const Tensor& self, int64_t k, IntArrayRef dims) {
  const int64_t total_dims = self.dim(), total_rot_dims = dims.size();

  TORCH_CHECK(total_rot_dims == 2,
    "expected total rotation dims == 2, but got dims = ", total_rot_dims);

  TORCH_CHECK(total_dims >= 2,
    "expected total dims >= 2, but got total dims = ", total_dims);

  TORCH_CHECK(dims[0] != dims[1] && std::abs(dims[0] - dims[1]) != total_dims,
    "expected rotation dims to be different, but got dim0 = ", dims[0],
    " and dim1 = ", dims[1]);

  // Check the range of dims
  TORCH_CHECK(dims[0] < total_dims && dims[0] >= -total_dims,
    "Rotation dim0 out of range, dim0 = ", dims[0]);

  TORCH_CHECK(dims[1] < total_dims && dims[1] >= -total_dims,
    "Rotation dim1 out of range, dim1 = ", dims[1]);

  // Handle the modulo with a negative k
  k = (4 + (k % 4)) % 4;
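
  // For example (illustrative): with dims = {0, 1} and k = 1,
  // [[1, 2], [3, 4]] is flipped along dim 1 to [[2, 1], [4, 3]] and then
  // transposed, giving [[2, 4], [1, 3]].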

  switch (k) {
    case 1:
      return self.flip({dims[1]}).transpose_(dims[0], dims[1]);
    case 2:
      return self.flip(dims);
    case 3:
      return self.flip({dims[0]}).transpose_(dims[0], dims[1]);
    default:
      return self.clone(at::MemoryFormat::Contiguous);
  }
}

Tensor fliplr(const Tensor& self) {
  TORCH_CHECK(self.dim() >= 2, "Input must be >= 2-d.");

  return self.flip({1});
}

Tensor flipud(const Tensor& self) {
  TORCH_CHECK(self.dim() >= 1, "Input must be >= 1-d.");

  return self.flip({0});
}
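
// For example (illustrative): atleast_1d turns a 0-dim scalar tensor into a
// tensor of shape {1}; tensors that already have one or more dims are
// returned as-is.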

Tensor atleast_1d(const Tensor& self) {
  switch (self.dim()) {
    case 0:
      return self.reshape({1});
    default:
      return self;
  }
}

std::vector<Tensor> atleast_1d(TensorList tensors) {
  std::vector<Tensor> result(tensors.size());
  auto transform_lambda = [](const Tensor& input) -> Tensor {
    return at::native::atleast_1d(input);
  };
  std::transform(tensors.cbegin(), tensors.cend(), result.begin(), transform_lambda);
  return result;
}
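
// For example (illustrative): atleast_2d reshapes a 0-dim tensor to {1, 1}
// and unsqueezes a 1-D tensor of shape {N} to {1, N}; anything with two or
// more dims is returned as-is.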

Tensor atleast_2d(const Tensor& self) {
  switch (self.dim()) {
    case 0:
      return self.reshape({1, 1});
    case 1: {
      return self.unsqueeze(0);
    }
    default:
      return self;
  }
}

std::vector<Tensor> atleast_2d(TensorList tensors) {
  std::vector<Tensor> result(tensors.size());
  auto transform_lambda = [](const Tensor& input) -> Tensor {
    return at::native::atleast_2d(input);
  };
  std::transform(tensors.cbegin(), tensors.cend(), result.begin(), transform_lambda);
  return result;
}
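
// For example (illustrative): atleast_3d maps a 0-dim tensor to {1, 1, 1},
// a 1-D tensor of shape {N} to {1, N, 1}, and a 2-D tensor of shape {M, N}
// to {M, N, 1}; anything with three or more dims is returned as-is.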

Tensor atleast_3d(const Tensor& self) {
  switch (self.dim()) {
    case 0:
      return self.reshape({1, 1, 1});
    case 1: {
      return self.unsqueeze(0).unsqueeze(-1);
    }
    case 2: {
      return self.unsqueeze(-1);
    }
    default:
      return self;
  }
}

std::vector<Tensor> atleast_3d(TensorList tensors) {
  std::vector<Tensor> result(tensors.size());
  auto transform_lambda = [](const Tensor& input) -> Tensor {
    return at::native::atleast_3d(input);
  };
  std::transform(tensors.cbegin(), tensors.cend(), result.begin(), transform_lambda);
  return result;
}
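
// For example (illustrative): chalf on a kFloat tensor returns a
// kComplexHalf copy with zero imaginary parts, via the Tensor::to dtype
// conversion below.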

Tensor chalf(const Tensor& self, std::optional<MemoryFormat> memory_format) {
  return self.to(kComplexHalf, false, false, memory_format);
}

DEFINE_DISPATCH(flip_stub);

} // namespace at::native