1 #pragma once
2
3 #include <ATen/CollapseDims.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorUtils.h>
6 #include <c10/util/irange.h>
7 #include <cstring>
8 #include <limits>
9
10 namespace at {
11
12 /*
13 * The basic strategy for apply is as follows:
14 *
15 * 1. Starting with the outermost index, loop until we reach a dimension where
16 * the data is no longer contiguous, i.e. the stride at that dimension is not
17 * equal to the size of the tensor defined by the outer dimensions. Let's call
18 * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
19 * A is equal to the entire Tensor. Let's call the inner tensor B.
20 *
21 * 2. We loop through the indices in B, starting at its outermost dimension. For
22 * example, if B is a 2x2 matrix, then we do:
23 *
24 * B[0][0]
25 * B[0][1]
26 * B[1][0]
27 * B[1][1]
28 *
29 * We set the offset into the underlying storage as (storageOffset + stride_B *
30 * index_B), i.e. basically we compute the offset into the storage as we would
31 * normally for a Tensor. But because we are guaranteed the subsequent data is
32 * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
33 * the operation, without having to follow the order described by the strides of
34 * A.
35 *
36 * 3. As an optimization, we merge dimensions of A that are contiguous in
37 * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
38 * then the first two dimensions can be merged for the purposes of APPLY,
39 * reducing the number of nested loops.
40 */
41
sort_strides(Tensor & tensor_)42 inline Tensor sort_strides(Tensor& tensor_) {
43 IntArrayRef strides = tensor_.strides();
44 std::vector<int64_t> indices;
45 indices.reserve(tensor_.ndimension());
46 for (const auto i : c10::irange(tensor_.ndimension())) {
47 indices.push_back(i);
48 }
49 std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
50 return strides[i1] > strides[i2];
51 });
52 Tensor tensor = tensor_.permute(indices);
53 return tensor;
54 }
55
56 template <typename T, int N>
57 struct strided_tensor_iter_fixed {
58 public:
59 T* data_ = NULL;
60 int64_t dim_ = 0;
61
62 int64_t counter_[N] = {0};
63 int64_t sizes_[N] = {0};
64 int64_t strides_[N] = {0};
65
66 strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
67 void operator=(strided_tensor_iter_fixed const& x) = delete;
68 strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
69 strided_tensor_iter_fixed(
70 Tensor& tensor,
71 C10_UNUSED bool sort_strides = false)
72 : data_(tensor.data_ptr<T>()) {
73 std::memset(counter_, 0, sizeof(int64_t) * N);
74 if (tensor.dim() > 0) {
75 std::memcpy(
76 sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
77 std::memcpy(
78 strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
79 }
80 dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
81 }
82 };
83
84 template <typename T>
85 struct strided_tensor_iter {
86 private:
87 public:
88 T* data_ = NULL;
89 int64_t dim_;
90
91 std::vector<int64_t> counter_;
92 std::vector<int64_t> sizes_;
93 std::vector<int64_t> strides_;
94
95 strided_tensor_iter(strided_tensor_iter const&) = delete;
96 void operator=(strided_tensor_iter const& x) = delete;
97 strided_tensor_iter(strided_tensor_iter&&) = default;
strided_tensor_iterstrided_tensor_iter98 strided_tensor_iter(Tensor& tensor)
99 : data_(tensor.data_ptr<T>()),
100 dim_(tensor.ndimension()),
101 counter_(dim_, 0),
102 sizes_(tensor.sizes().vec()),
103 strides_(tensor.strides().vec()) {
104 dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
105 }
106 };
107
_all_equal_numel(at::ArrayRef<Tensor> tensors)108 inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
109 if (tensors.empty())
110 return true;
111 int64_t all_numel = tensors[0].numel();
112 for (const auto i : c10::irange(1, tensors.size())) {
113 if (tensors[i].numel() != all_numel)
114 return false;
115 }
116 return true;
117 }
118
_all_equal_numel_error(at::ArrayRef<Tensor> tensors)119 inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
120 std::ostringstream oss;
121 oss << "inconsistent tensor size, expected ";
122 for (size_t i = 0; i < tensors.size() - 1; i++) {
123 oss << tensors[i].sizes() << ", ";
124 }
125 oss << "and " << tensors[tensors.size() - 1].sizes()
126 << " to have the same number of elements, but got ";
127 for (size_t i = 0; i < tensors.size() - 1; i++) {
128 oss << tensors[i].numel() << ", ";
129 }
130 oss << "and " << tensors[tensors.size() - 1].numel()
131 << " elements respectively";
132 return oss.str();
133 }
134
_apply_preamble(ArrayRef<Tensor> tensors)135 inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
136 checkDeviceType("CPU_tensor_apply", tensors, kCPU);
137 checkLayout("CPU_tensor_apply", tensors, kStrided);
138 if (!_all_equal_numel(tensors))
139 AT_ERROR(_all_equal_numel_error(tensors));
140 // An empty tensor has no elements
141 for (auto& t : tensors)
142 if (t.numel() == 0)
143 return false;
144 return true;
145 }
146
_max_dim_tensors(ArrayRef<Tensor> tensors)147 inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
148 int64_t dim = 0;
149 for (auto& t : tensors)
150 dim = std::max(dim, t.ndimension());
151 return dim;
152 }
153
// Recursion terminator: no iterators left to advance.
inline void iterate(int64_t /*size*/) {}

// Advances every iterator by `size` steps along its last dimension
// (index dim_ - 1): the counter for that dimension grows by `size` and the
// data pointer moves by `size` times that dimension's stride.
// No wrap-around is performed here; see iterate_overflow.
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  const auto innermost = iter.dim_ - 1;
  iter.counter_[innermost] += size;
  iter.data_ += size * iter.strides_[innermost];
  iterate(size, iter_tail...);
}
162
// Recursion terminator: with no iterators left, nothing vetoes continuing.
inline bool iterate_continue() {
  return true;
}

// Returns true only if every iterator's counter for its last dimension
// (index dim_ - 1) is still below that dimension's size.
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  const auto innermost = iter.dim_ - 1;
  if (iter.counter_[innermost] >= iter.sizes_[innermost]) {
    return false;
  }
  return iterate_continue(iter_tail...);
}
172
// Recursion terminator: with no iterators there is no limit, so report the
// largest representable int64_t.
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

// Returns the smallest "steps remaining" value across all iterators, where an
// iterator's remaining steps are the size of its last dimension (index
// dim_ - 1) minus its counter for that dimension.
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  const auto innermost = iter.dim_ - 1;
  const int64_t remaining_here =
      iter.sizes_[innermost] - iter.counter_[innermost];
  const int64_t remaining_rest = max_iterate_size(iter_tail...);
  return remaining_rest < remaining_here ? remaining_rest : remaining_here;
}
183
// Recursion terminator: no iterators left to normalize.
inline void iterate_overflow() {}

// Carries a finished last dimension into the outer dimensions, like an
// odometer. For each iterator whose counter for the last dimension equals
// that dimension's size, walk from the last dimension towards dimension 1;
// every dimension whose counter has reached its size is reset to zero, the
// next outer counter is incremented, and the data pointer is rewound by the
// full extent of the reset dimension and advanced by one outer stride.
// Dimension 0 is never reset.
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  const int64_t innermost = iter.dim_ - 1;
  if (iter.counter_[innermost] == iter.sizes_[innermost]) {
    for (int64_t d = innermost; d > 0; --d) {
      if (iter.counter_[d] != iter.sizes_[d]) {
        // Nothing to carry at this level; outer levels are still examined.
        continue;
      }
      iter.counter_[d] = 0;
      iter.counter_[d - 1]++;
      const auto rewind = iter.sizes_[d] * iter.strides_[d];
      iter.data_ = iter.data_ - rewind + iter.strides_[d - 1];
    }
  }
  iterate_overflow(iter_tail...);
}
200
// Recursion terminator: no iterators left to move.
inline void forward(int64_t /*offset*/) {}

// Moves every iterator ahead by `offset` elements. The linear offset is
// decomposed into per-dimension steps, starting at the last dimension:
// the step for a dimension is the remainder modulo that dimension's size and
// the quotient is carried to the next outer dimension. Each step is added to
// the dimension's counter and, scaled by the stride, to the data pointer.
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t remaining = offset;
  for (int64_t d = iter.dim_ - 1; d >= 0; --d) {
    const int64_t step = remaining % iter.sizes_[d];
    remaining = remaining / iter.sizes_[d];
    iter.data_ += step * iter.strides_[d];
    iter.counter_[d] += step;
  }
  forward(offset, iter_tail...);
}
214
// Recursion terminator: with no iterators the maximum is 0.
inline int64_t max_dim() {
  return 0;
}

// Returns the largest dim_ among the given iterators (never less than 0).
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  const int64_t rest = max_dim(iter_tail...);
  return iter.dim_ < rest ? rest : iter.dim_;
}
223
// Overload taking no arguments; does nothing.
inline void apply_op() {}

// Calls `op` once per element for `numel` elements, passing it the
// dereferenced data pointer of every iterator (op(*iter.data_...)).
//
// numel  - number of elements to visit.
// offset - if positive, all iterators are first moved ahead by this many
//          elements via forward().
// op     - callable invoked with one dereferenced element per iterator.
// iters  - iterator objects, taken by value, so this function advances its
//          own instances.
template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  const bool single_scalar = (numel == 1) && (max_dim(iters...) == 0);
  if (single_scalar) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0) {
    forward(offset, iters...);
  }
  // Splitting this into chunks helps the compiler create faster assembly
  int64_t visited = 0;
  while (visited < numel) {
    // Inner chunk: run along the last dimension until some iterator reaches
    // the end of it (or all requested elements have been visited).
    while (iterate_continue(iters...) && visited < numel) {
      op(*iters.data_...);
      iterate(1, iters...);
      ++visited;
    }
    // Carry finished last dimensions into the outer dimensions.
    iterate_overflow(iters...);
  }
}
249
/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes one argument
  per given tensor. Each argument is a reference to the current element of the
  corresponding tensor (op is invoked with the dereferenced data pointers, not
  with the pointers themselves). For example, to compute a = b * c, op would
  be of the form:
  [](scalar& a_val, const scalar& b_val, const scalar& c_val) {
    a_val = b_val * c_val;
  };
*/
259
260 template <typename scalar1, typename scalar2, typename Op>
CPU_tensor_apply2(Tensor tensor1,Tensor tensor2,const Op op)261 inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
262 if (!_apply_preamble({tensor1, tensor2}))
263 return;
264 if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
265 apply_op(
266 tensor1.numel(),
267 0,
268 op,
269 strided_tensor_iter_fixed<scalar1, 8>(tensor1),
270 strided_tensor_iter_fixed<scalar2, 8>(tensor2));
271 } else {
272 apply_op(
273 tensor1.numel(),
274 0,
275 op,
276 strided_tensor_iter<scalar1>(tensor1),
277 strided_tensor_iter<scalar2>(tensor2));
278 }
279 }
280
281 template <typename scalar1, typename scalar2, typename scalar3, typename Op>
CPU_tensor_apply3(Tensor tensor1,Tensor tensor2,Tensor tensor3,const Op op)282 inline void CPU_tensor_apply3(
283 Tensor tensor1,
284 Tensor tensor2,
285 Tensor tensor3,
286 const Op op) {
287 if (!_apply_preamble({tensor1, tensor2, tensor3}))
288 return;
289 if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
290 apply_op(
291 tensor1.numel(),
292 0,
293 op,
294 strided_tensor_iter_fixed<scalar1, 8>(tensor1),
295 strided_tensor_iter_fixed<scalar2, 8>(tensor2),
296 strided_tensor_iter_fixed<scalar3, 8>(tensor3));
297 } else {
298 apply_op(
299 tensor1.numel(),
300 0,
301 op,
302 strided_tensor_iter<scalar1>(tensor1),
303 strided_tensor_iter<scalar2>(tensor2),
304 strided_tensor_iter<scalar3>(tensor3));
305 }
306 }
307
308 template <
309 typename scalar1,
310 typename scalar2,
311 typename scalar3,
312 typename scalar4,
313 typename Op>
CPU_tensor_apply4(Tensor tensor1,Tensor tensor2,Tensor tensor3,Tensor tensor4,const Op op)314 inline void CPU_tensor_apply4(
315 Tensor tensor1,
316 Tensor tensor2,
317 Tensor tensor3,
318 Tensor tensor4,
319 const Op op) {
320 if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
321 return;
322 if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
323 apply_op(
324 tensor1.numel(),
325 0,
326 op,
327 strided_tensor_iter_fixed<scalar1, 8>(tensor1),
328 strided_tensor_iter_fixed<scalar2, 8>(tensor2),
329 strided_tensor_iter_fixed<scalar3, 8>(tensor3),
330 strided_tensor_iter_fixed<scalar4, 8>(tensor4));
331 } else {
332 apply_op(
333 tensor1.numel(),
334 0,
335 op,
336 strided_tensor_iter<scalar1>(tensor1),
337 strided_tensor_iter<scalar2>(tensor2),
338 strided_tensor_iter<scalar3>(tensor3),
339 strided_tensor_iter<scalar4>(tensor4));
340 }
341 }
342
343 } // namespace at
344