#pragma once

#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace at {

/*
 * The basic strategy for apply is as follows:
 *
 * 1. Starting with the outermost index, loop until we reach a dimension where
 * the data is no longer contiguous, i.e. the stride at that dimension is not
 * equal to the size of the tensor defined by the outer dimensions. Let's call
 * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
 * A is equal to the entire Tensor. Let's call the inner tensor B.
 *
 * 2. We loop through the indices in B, starting at its outermost dimension. For
 * example, if B is a 2x2 matrix, then we do:
 *
 * B[0][0]
 * B[0][1]
 * B[1][0]
 * B[1][1]
 *
 * We set the offset into the underlying storage as (storageOffset + stride_B *
 * index_B), i.e. basically we compute the offset into the storage as we would
 * normally for a Tensor. But because we are guaranteed the subsequent data is
 * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
 * the operation, without having to follow the order described by the strides of
 * A.
 *
 * 3. As an optimization, we merge dimensions of A that are contiguous in
 * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
 * then the first two dimensions can be merged for the purposes of APPLY,
 * reducing the number of nested loops (see the illustrative sketch below).
 */

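/*
 * Illustrative sketch (not part of this header's API): the dimension merging
 * in (3) is what collapse_dims from ATen/CollapseDims.h performs on the raw
 * sizes/strides arrays. For the narrowed tensor mentioned above:
 *
 *   int64_t sizes[] = {3, 3, 3, 3};      // 3x3x3x3 view of a 3x3x4x3 tensor
 *   int64_t strides[] = {36, 12, 3, 1};  // strides of the original 3x3x4x3
 *   auto collapsed = collapse_dims(sizes, strides, 4);
 *   // sizes/strides are rewritten in place; std::get<1>(collapsed) is the new
 *   // number of dimensions. Here sizes becomes {9, 9} and strides {12, 1}, so
 *   // only two nested loops are needed instead of four.
 */
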
inline Tensor sort_strides(Tensor& tensor_) {
  IntArrayRef strides = tensor_.strides();
  std::vector<int64_t> indices;
  indices.reserve(tensor_.ndimension());
  for (const auto i : c10::irange(tensor_.ndimension())) {
    indices.push_back(i);
  }
  std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
    return strides[i1] > strides[i2];
  });
  Tensor tensor = tensor_.permute(indices);
  return tensor;
}

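// Iterator state for walking a strided tensor element by element when its
// dimensionality fits in fixed-size arrays of N entries. sizes_/strides_ are
// collapsed on construction and data_ always points at the current element.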
template <typename T, int N>
struct strided_tensor_iter_fixed {
 public:
  T* data_ = nullptr;
  int64_t dim_ = 0;

  int64_t counter_[N] = {0};
  int64_t sizes_[N] = {0};
  int64_t strides_[N] = {0};

  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  void operator=(strided_tensor_iter_fixed const& x) = delete;
  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
  strided_tensor_iter_fixed(
      Tensor& tensor,
      C10_UNUSED bool sort_strides = false)
      : data_(tensor.data_ptr<T>()) {
    std::memset(counter_, 0, sizeof(int64_t) * N);
    if (tensor.dim() > 0) {
      std::memcpy(
          sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
      std::memcpy(
          strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
    }
    dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  }
};

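// Heap-backed variant of strided_tensor_iter_fixed, used by the apply
// functions below when a tensor has more dimensions than the fixed-size
// iterator can hold.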
template <typename T>
struct strided_tensor_iter {
 public:
  T* data_ = nullptr;
  int64_t dim_;

  std::vector<int64_t> counter_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  strided_tensor_iter(strided_tensor_iter const&) = delete;
  void operator=(strided_tensor_iter const& x) = delete;
  strided_tensor_iter(strided_tensor_iter&&) = default;
  strided_tensor_iter(Tensor& tensor)
      : data_(tensor.data_ptr<T>()),
        dim_(tensor.ndimension()),
        counter_(dim_, 0),
        sizes_(tensor.sizes().vec()),
        strides_(tensor.strides().vec()) {
    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  }
};

inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (tensors.empty())
    return true;
  int64_t all_numel = tensors[0].numel();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (tensors[i].numel() != all_numel)
      return false;
  }
  return true;
}

inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  std::ostringstream oss;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].numel()
      << " elements respectively";
  return oss.str();
}

inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  if (!_all_equal_numel(tensors))
    AT_ERROR(_all_equal_numel_error(tensors));
  // An empty tensor has no elements, so there is nothing to apply.
  for (auto& t : tensors)
    if (t.numel() == 0)
      return false;
  return true;
}

inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t dim = 0;
  for (auto& t : tensors)
    dim = std::max(dim, t.ndimension());
  return dim;
}

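// Advance every iterator by `size` positions along its innermost collapsed
// dimension; carrying into outer dimensions is handled separately by
// iterate_overflow.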
inline void iterate(int64_t /*size*/) {}

template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  iter.counter_[iter.dim_ - 1] += size;
  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
  iterate(size, iter_tail...);
}

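// True while every iterator's innermost index is still within bounds, i.e. the
// current contiguous run can continue.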
inline bool iterate_continue() {
  return true;
}

template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
      iterate_continue(iter_tail...);
}

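// Smallest remaining extent along the innermost dimension across all
// iterators, i.e. the largest step that can be taken before one of them
// overflows.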
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  return std::min(
      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
      max_iterate_size(iter_tail...));
}

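// Once an iterator's innermost index reaches its size, propagate the carry
// into outer dimensions, rewinding the data pointer along the overflowed
// dimension and stepping the next outer one.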
inline void iterate_overflow() {}

template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
      if (iter.counter_[i] == iter.sizes_[i]) {
        iter.counter_[i] = 0;
        iter.counter_[i - 1]++;
        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
            iter.strides_[i - 1];
      }
    }
  }
  iterate_overflow(iter_tail...);
}

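// Advance each iterator by a flat element `offset`, decomposing the offset
// into per-dimension increments (used to start applying from the middle of a
// tensor).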
inline void forward(int64_t /*offset*/) {}

template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t multi = offset;
  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
    int64_t inc = multi % iter.sizes_[i];
    multi = multi / iter.sizes_[i];
    iter.data_ = iter.data_ + inc * iter.strides_[i];
    iter.counter_[i] += inc;
  }
  forward(offset, iter_tail...);
}

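// Maximum collapsed dimensionality across the given iterators.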
inline int64_t max_dim() {
  return 0;
}

template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  return std::max(iter.dim_, max_dim(iter_tail...));
}

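// Apply `op` to `numel` elements starting at flat `offset`, advancing all
// iterators in lockstep one element at a time.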
inline void apply_op() {}

template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  if (numel == 1 && max_dim(iters...) == 0) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0)
    forward(offset, iters...);
  // Splitting this into chunks helps the compiler create faster assembly
  for (int64_t i = 0; i < numel;) {
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}

/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes the same
  number of scalar references as the number of given tensors (apply_op
  dereferences each iterator's data pointer before the call). For example,
  to compute a = b * c, op would be of the form:
  [](scalar& a_val, const scalar& b_val, const scalar& c_val) {
    a_val = b_val * c_val;
  };
*/

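/*
 * A minimal usage sketch (illustrative only; the shapes and dtype are
 * arbitrary). Computing a = b * c elementwise over float tensors with
 * CPU_tensor_apply3:
 *
 *   at::Tensor b = at::rand({2, 3});
 *   at::Tensor c = at::rand({2, 3});
 *   at::Tensor a = at::empty({2, 3});
 *   at::CPU_tensor_apply3<float, float, float>(
 *       a, b, c, [](float& a_val, const float& b_val, const float& c_val) {
 *         a_val = b_val * c_val;
 *       });
 */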
template <typename scalar1, typename scalar2, typename Op>
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  if (!_apply_preamble({tensor1, tensor2}))
    return;
  if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2));
  }
}

template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void CPU_tensor_apply3(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3));
  }
}

template <
    typename scalar1,
    typename scalar2,
    typename scalar3,
    typename scalar4,
    typename Op>
inline void CPU_tensor_apply4(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    Tensor tensor4,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3),
        strided_tensor_iter_fixed<scalar4, 8>(tensor4));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3),
        strided_tensor_iter<scalar4>(tensor4));
  }
}

} // namespace at