xref: /aosp_15_r20/external/pytorch/aten/src/ATen/CPUApplyUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/CollapseDims.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/TensorUtils.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
7*da0073e9SAndroid Build Coastguard Worker #include <cstring>
8*da0073e9SAndroid Build Coastguard Worker #include <limits>
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker namespace at {
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker /*
13*da0073e9SAndroid Build Coastguard Worker  * The basic strategy for apply is as follows:
14*da0073e9SAndroid Build Coastguard Worker  *
15*da0073e9SAndroid Build Coastguard Worker  * 1. Starting with the outermost index, loop until we reach a dimension where
16*da0073e9SAndroid Build Coastguard Worker  * the data is no longer contiguous, i.e. the stride at that dimension is not
17*da0073e9SAndroid Build Coastguard Worker  * equal to the size of the tensor defined by the outer dimensions. Let's call
18*da0073e9SAndroid Build Coastguard Worker  * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
19*da0073e9SAndroid Build Coastguard Worker  * A is equal to the entire Tensor. Let's call the inner tensor B.
20*da0073e9SAndroid Build Coastguard Worker  *
21*da0073e9SAndroid Build Coastguard Worker  * 2. We loop through the indices in B, starting at its outermost dimension. For
22*da0073e9SAndroid Build Coastguard Worker  * example, if B is a 2x2 matrix, then we do:
23*da0073e9SAndroid Build Coastguard Worker  *
24*da0073e9SAndroid Build Coastguard Worker  * B[0][0]
25*da0073e9SAndroid Build Coastguard Worker  * B[0][1]
26*da0073e9SAndroid Build Coastguard Worker  * B[1][0]
27*da0073e9SAndroid Build Coastguard Worker  * B[1][1]
28*da0073e9SAndroid Build Coastguard Worker  *
29*da0073e9SAndroid Build Coastguard Worker  * We set the offset into the underlying storage as (storageOffset + stride_B *
30*da0073e9SAndroid Build Coastguard Worker  * index_B), i.e. basically we compute the offset into the storage as we would
31*da0073e9SAndroid Build Coastguard Worker  * normally for a Tensor. But because we are guaranteed the subsequent data is
32*da0073e9SAndroid Build Coastguard Worker  * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
33*da0073e9SAndroid Build Coastguard Worker  * the operation, without having to follow the order described by the strides of
34*da0073e9SAndroid Build Coastguard Worker  * A.
35*da0073e9SAndroid Build Coastguard Worker  *
36*da0073e9SAndroid Build Coastguard Worker  * 3. As an optimization, we merge dimensions of A that are contiguous in
37*da0073e9SAndroid Build Coastguard Worker  * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
38*da0073e9SAndroid Build Coastguard Worker  * then the first two dimensions can be merged for the purposes of APPLY,
39*da0073e9SAndroid Build Coastguard Worker  * reducing the number of nested loops.
40*da0073e9SAndroid Build Coastguard Worker  */
41*da0073e9SAndroid Build Coastguard Worker 
sort_strides(Tensor & tensor_)42*da0073e9SAndroid Build Coastguard Worker inline Tensor sort_strides(Tensor& tensor_) {
43*da0073e9SAndroid Build Coastguard Worker   IntArrayRef strides = tensor_.strides();
44*da0073e9SAndroid Build Coastguard Worker   std::vector<int64_t> indices;
45*da0073e9SAndroid Build Coastguard Worker   indices.reserve(tensor_.ndimension());
46*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(tensor_.ndimension())) {
47*da0073e9SAndroid Build Coastguard Worker     indices.push_back(i);
48*da0073e9SAndroid Build Coastguard Worker   }
49*da0073e9SAndroid Build Coastguard Worker   std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
50*da0073e9SAndroid Build Coastguard Worker     return strides[i1] > strides[i2];
51*da0073e9SAndroid Build Coastguard Worker   });
52*da0073e9SAndroid Build Coastguard Worker   Tensor tensor = tensor_.permute(indices);
53*da0073e9SAndroid Build Coastguard Worker   return tensor;
54*da0073e9SAndroid Build Coastguard Worker }
55*da0073e9SAndroid Build Coastguard Worker 
56*da0073e9SAndroid Build Coastguard Worker template <typename T, int N>
57*da0073e9SAndroid Build Coastguard Worker struct strided_tensor_iter_fixed {
58*da0073e9SAndroid Build Coastguard Worker  public:
59*da0073e9SAndroid Build Coastguard Worker   T* data_ = NULL;
60*da0073e9SAndroid Build Coastguard Worker   int64_t dim_ = 0;
61*da0073e9SAndroid Build Coastguard Worker 
62*da0073e9SAndroid Build Coastguard Worker   int64_t counter_[N] = {0};
63*da0073e9SAndroid Build Coastguard Worker   int64_t sizes_[N] = {0};
64*da0073e9SAndroid Build Coastguard Worker   int64_t strides_[N] = {0};
65*da0073e9SAndroid Build Coastguard Worker 
66*da0073e9SAndroid Build Coastguard Worker   strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
67*da0073e9SAndroid Build Coastguard Worker   void operator=(strided_tensor_iter_fixed const& x) = delete;
68*da0073e9SAndroid Build Coastguard Worker   strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
69*da0073e9SAndroid Build Coastguard Worker   strided_tensor_iter_fixed(
70*da0073e9SAndroid Build Coastguard Worker       Tensor& tensor,
71*da0073e9SAndroid Build Coastguard Worker       C10_UNUSED bool sort_strides = false)
72*da0073e9SAndroid Build Coastguard Worker       : data_(tensor.data_ptr<T>()) {
73*da0073e9SAndroid Build Coastguard Worker     std::memset(counter_, 0, sizeof(int64_t) * N);
74*da0073e9SAndroid Build Coastguard Worker     if (tensor.dim() > 0) {
75*da0073e9SAndroid Build Coastguard Worker       std::memcpy(
76*da0073e9SAndroid Build Coastguard Worker           sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
77*da0073e9SAndroid Build Coastguard Worker       std::memcpy(
78*da0073e9SAndroid Build Coastguard Worker           strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
79*da0073e9SAndroid Build Coastguard Worker     }
80*da0073e9SAndroid Build Coastguard Worker     dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
81*da0073e9SAndroid Build Coastguard Worker   }
82*da0073e9SAndroid Build Coastguard Worker };
83*da0073e9SAndroid Build Coastguard Worker 
84*da0073e9SAndroid Build Coastguard Worker template <typename T>
85*da0073e9SAndroid Build Coastguard Worker struct strided_tensor_iter {
86*da0073e9SAndroid Build Coastguard Worker  private:
87*da0073e9SAndroid Build Coastguard Worker  public:
88*da0073e9SAndroid Build Coastguard Worker   T* data_ = NULL;
89*da0073e9SAndroid Build Coastguard Worker   int64_t dim_;
90*da0073e9SAndroid Build Coastguard Worker 
91*da0073e9SAndroid Build Coastguard Worker   std::vector<int64_t> counter_;
92*da0073e9SAndroid Build Coastguard Worker   std::vector<int64_t> sizes_;
93*da0073e9SAndroid Build Coastguard Worker   std::vector<int64_t> strides_;
94*da0073e9SAndroid Build Coastguard Worker 
95*da0073e9SAndroid Build Coastguard Worker   strided_tensor_iter(strided_tensor_iter const&) = delete;
96*da0073e9SAndroid Build Coastguard Worker   void operator=(strided_tensor_iter const& x) = delete;
97*da0073e9SAndroid Build Coastguard Worker   strided_tensor_iter(strided_tensor_iter&&) = default;
strided_tensor_iterstrided_tensor_iter98*da0073e9SAndroid Build Coastguard Worker   strided_tensor_iter(Tensor& tensor)
99*da0073e9SAndroid Build Coastguard Worker       : data_(tensor.data_ptr<T>()),
100*da0073e9SAndroid Build Coastguard Worker         dim_(tensor.ndimension()),
101*da0073e9SAndroid Build Coastguard Worker         counter_(dim_, 0),
102*da0073e9SAndroid Build Coastguard Worker         sizes_(tensor.sizes().vec()),
103*da0073e9SAndroid Build Coastguard Worker         strides_(tensor.strides().vec()) {
104*da0073e9SAndroid Build Coastguard Worker     dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
105*da0073e9SAndroid Build Coastguard Worker   }
106*da0073e9SAndroid Build Coastguard Worker };
107*da0073e9SAndroid Build Coastguard Worker 
_all_equal_numel(at::ArrayRef<Tensor> tensors)108*da0073e9SAndroid Build Coastguard Worker inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
109*da0073e9SAndroid Build Coastguard Worker   if (tensors.empty())
110*da0073e9SAndroid Build Coastguard Worker     return true;
111*da0073e9SAndroid Build Coastguard Worker   int64_t all_numel = tensors[0].numel();
112*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(1, tensors.size())) {
113*da0073e9SAndroid Build Coastguard Worker     if (tensors[i].numel() != all_numel)
114*da0073e9SAndroid Build Coastguard Worker       return false;
115*da0073e9SAndroid Build Coastguard Worker   }
116*da0073e9SAndroid Build Coastguard Worker   return true;
117*da0073e9SAndroid Build Coastguard Worker }
118*da0073e9SAndroid Build Coastguard Worker 
_all_equal_numel_error(at::ArrayRef<Tensor> tensors)119*da0073e9SAndroid Build Coastguard Worker inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
120*da0073e9SAndroid Build Coastguard Worker   std::ostringstream oss;
121*da0073e9SAndroid Build Coastguard Worker   oss << "inconsistent tensor size, expected ";
122*da0073e9SAndroid Build Coastguard Worker   for (size_t i = 0; i < tensors.size() - 1; i++) {
123*da0073e9SAndroid Build Coastguard Worker     oss << tensors[i].sizes() << ", ";
124*da0073e9SAndroid Build Coastguard Worker   }
125*da0073e9SAndroid Build Coastguard Worker   oss << "and " << tensors[tensors.size() - 1].sizes()
126*da0073e9SAndroid Build Coastguard Worker       << " to have the same number of elements, but got ";
127*da0073e9SAndroid Build Coastguard Worker   for (size_t i = 0; i < tensors.size() - 1; i++) {
128*da0073e9SAndroid Build Coastguard Worker     oss << tensors[i].numel() << ", ";
129*da0073e9SAndroid Build Coastguard Worker   }
130*da0073e9SAndroid Build Coastguard Worker   oss << "and " << tensors[tensors.size() - 1].numel()
131*da0073e9SAndroid Build Coastguard Worker       << " elements respectively";
132*da0073e9SAndroid Build Coastguard Worker   return oss.str();
133*da0073e9SAndroid Build Coastguard Worker }
134*da0073e9SAndroid Build Coastguard Worker 
_apply_preamble(ArrayRef<Tensor> tensors)135*da0073e9SAndroid Build Coastguard Worker inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
136*da0073e9SAndroid Build Coastguard Worker   checkDeviceType("CPU_tensor_apply", tensors, kCPU);
137*da0073e9SAndroid Build Coastguard Worker   checkLayout("CPU_tensor_apply", tensors, kStrided);
138*da0073e9SAndroid Build Coastguard Worker   if (!_all_equal_numel(tensors))
139*da0073e9SAndroid Build Coastguard Worker     AT_ERROR(_all_equal_numel_error(tensors));
140*da0073e9SAndroid Build Coastguard Worker   // An empty tensor has no elements
141*da0073e9SAndroid Build Coastguard Worker   for (auto& t : tensors)
142*da0073e9SAndroid Build Coastguard Worker     if (t.numel() == 0)
143*da0073e9SAndroid Build Coastguard Worker       return false;
144*da0073e9SAndroid Build Coastguard Worker   return true;
145*da0073e9SAndroid Build Coastguard Worker }
146*da0073e9SAndroid Build Coastguard Worker 
_max_dim_tensors(ArrayRef<Tensor> tensors)147*da0073e9SAndroid Build Coastguard Worker inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
148*da0073e9SAndroid Build Coastguard Worker   int64_t dim = 0;
149*da0073e9SAndroid Build Coastguard Worker   for (auto& t : tensors)
150*da0073e9SAndroid Build Coastguard Worker     dim = std::max(dim, t.ndimension());
151*da0073e9SAndroid Build Coastguard Worker   return dim;
152*da0073e9SAndroid Build Coastguard Worker }
153*da0073e9SAndroid Build Coastguard Worker 
// Recursion base case: no iterators left to advance. (The original had a
// stray `;` after the body — an empty declaration flagged by -Wextra-semi.)
inline void iterate(int64_t /*size*/) {}

// Advance each iterator by `size` positions along its innermost (last
// collapsed) dimension, updating both the per-dimension counter and the data
// pointer. Running past the dimension's extent is handled separately by
// iterate_overflow().
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  iter.counter_[iter.dim_ - 1] += size;
  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
  iterate(size, iter_tail...);
}
162*da0073e9SAndroid Build Coastguard Worker 
// Recursion base case: with no iterators left to check, keep going. (Removed
// the stray `;` after the body — an empty declaration flagged by
// -Wextra-semi.)
inline bool iterate_continue() {
  return true;
}

// True while every iterator's innermost counter is still within bounds, i.e.
// none of them needs an overflow/carry step yet.
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
      iterate_continue(iter_tail...);
}
172*da0073e9SAndroid Build Coastguard Worker 
// Recursion base case: no iterator imposes a limit, so return "unbounded".
// (Removed the stray `;` after the body — flagged by -Wextra-semi.)
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

// Largest number of steps every iterator can take along its innermost
// dimension before any of them would overflow.
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  return std::min(
      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
      max_iterate_size(iter_tail...));
}
183*da0073e9SAndroid Build Coastguard Worker 
// Recursion base case: nothing left to normalize. (Removed the stray `;`
// after the body — an empty declaration flagged by -Wextra-semi.)
inline void iterate_overflow() {}

// Carry propagation: when an iterator's innermost counter has run past its
// extent, ripple the carry toward the outer dimensions — reset the overflowed
// counter, bump the next-outer one, and adjust data_ by rewinding the span of
// the overflowed dimension and advancing one stride of the outer dimension.
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
      if (iter.counter_[i] == iter.sizes_[i]) {
        iter.counter_[i] = 0;
        iter.counter_[i - 1]++;
        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
            iter.strides_[i - 1];
      }
    }
  }
  iterate_overflow(iter_tail...);
}
200*da0073e9SAndroid Build Coastguard Worker 
// Recursion base case: no iterators left to position. (Removed the stray `;`
// after the body — an empty declaration flagged by -Wextra-semi.)
inline void forward(int64_t /*offset*/) {}

// Position each iterator at the given linear offset by decomposing the offset
// into per-dimension indices (innermost dimension varies fastest), advancing
// the data pointer and counters accordingly.
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t multi = offset;
  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
    int64_t inc = multi % iter.sizes_[i];
    multi = multi / iter.sizes_[i];
    iter.data_ = iter.data_ + inc * iter.strides_[i];
    iter.counter_[i] += inc;
  }
  forward(offset, iter_tail...);
}
214*da0073e9SAndroid Build Coastguard Worker 
// Recursion base case: zero iterators have a maximum dimensionality of 0.
inline int64_t max_dim() {
  return 0;
}

// Largest dim_ among the given iterators.
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  const int64_t rest = max_dim(iter_tail...);
  return iter.dim_ > rest ? iter.dim_ : rest;
}
223*da0073e9SAndroid Build Coastguard Worker 
// Recursion-style zero-argument placeholder (no-op). (Removed the stray `;`
// after the body — an empty declaration flagged by -Wextra-semi.)
inline void apply_op() {}

// Core apply loop: invokes `op` once per element, advancing all iterators in
// lockstep. `offset` allows starting part-way through the flattened element
// range (useful for chunked/parallel dispatch). 0-dim (scalar) tensors are
// special-cased since they have no dimensions to iterate over.
template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  if (numel == 1 && max_dim(iters...) == 0) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0)
    forward(offset, iters...);
  // Splitting this into chunks helps the compiler create faster assembly
  for (int64_t i = 0; i < numel;) {
    // Inner loop runs while no iterator needs a carry step.
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}
249*da0073e9SAndroid Build Coastguard Worker 
250*da0073e9SAndroid Build Coastguard Worker /*
251*da0073e9SAndroid Build Coastguard Worker   Apply a pointwise operator to sequence of tensors
252*da0073e9SAndroid Build Coastguard Worker 
253*da0073e9SAndroid Build Coastguard Worker   The calling convention for op is a function/functor that takes the same
254*da0073e9SAndroid Build Coastguard Worker   number of pointers of type scalar as the number of given tensors. For example,
255*da0073e9SAndroid Build Coastguard Worker   to compute a = b * c, op would be of the form:
256*da0073e9SAndroid Build Coastguard Worker   [](scalar* a_val, const scalar* b_val, const scalar* c_val) { a_val[0] =
257*da0073e9SAndroid Build Coastguard Worker   b_val[0] * c_val[0]; };
258*da0073e9SAndroid Build Coastguard Worker */
259*da0073e9SAndroid Build Coastguard Worker 
260*da0073e9SAndroid Build Coastguard Worker template <typename scalar1, typename scalar2, typename Op>
CPU_tensor_apply2(Tensor tensor1,Tensor tensor2,const Op op)261*da0073e9SAndroid Build Coastguard Worker inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
262*da0073e9SAndroid Build Coastguard Worker   if (!_apply_preamble({tensor1, tensor2}))
263*da0073e9SAndroid Build Coastguard Worker     return;
264*da0073e9SAndroid Build Coastguard Worker   if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
265*da0073e9SAndroid Build Coastguard Worker     apply_op(
266*da0073e9SAndroid Build Coastguard Worker         tensor1.numel(),
267*da0073e9SAndroid Build Coastguard Worker         0,
268*da0073e9SAndroid Build Coastguard Worker         op,
269*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar1, 8>(tensor1),
270*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar2, 8>(tensor2));
271*da0073e9SAndroid Build Coastguard Worker   } else {
272*da0073e9SAndroid Build Coastguard Worker     apply_op(
273*da0073e9SAndroid Build Coastguard Worker         tensor1.numel(),
274*da0073e9SAndroid Build Coastguard Worker         0,
275*da0073e9SAndroid Build Coastguard Worker         op,
276*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar1>(tensor1),
277*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar2>(tensor2));
278*da0073e9SAndroid Build Coastguard Worker   }
279*da0073e9SAndroid Build Coastguard Worker }
280*da0073e9SAndroid Build Coastguard Worker 
281*da0073e9SAndroid Build Coastguard Worker template <typename scalar1, typename scalar2, typename scalar3, typename Op>
CPU_tensor_apply3(Tensor tensor1,Tensor tensor2,Tensor tensor3,const Op op)282*da0073e9SAndroid Build Coastguard Worker inline void CPU_tensor_apply3(
283*da0073e9SAndroid Build Coastguard Worker     Tensor tensor1,
284*da0073e9SAndroid Build Coastguard Worker     Tensor tensor2,
285*da0073e9SAndroid Build Coastguard Worker     Tensor tensor3,
286*da0073e9SAndroid Build Coastguard Worker     const Op op) {
287*da0073e9SAndroid Build Coastguard Worker   if (!_apply_preamble({tensor1, tensor2, tensor3}))
288*da0073e9SAndroid Build Coastguard Worker     return;
289*da0073e9SAndroid Build Coastguard Worker   if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
290*da0073e9SAndroid Build Coastguard Worker     apply_op(
291*da0073e9SAndroid Build Coastguard Worker         tensor1.numel(),
292*da0073e9SAndroid Build Coastguard Worker         0,
293*da0073e9SAndroid Build Coastguard Worker         op,
294*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar1, 8>(tensor1),
295*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar2, 8>(tensor2),
296*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar3, 8>(tensor3));
297*da0073e9SAndroid Build Coastguard Worker   } else {
298*da0073e9SAndroid Build Coastguard Worker     apply_op(
299*da0073e9SAndroid Build Coastguard Worker         tensor1.numel(),
300*da0073e9SAndroid Build Coastguard Worker         0,
301*da0073e9SAndroid Build Coastguard Worker         op,
302*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar1>(tensor1),
303*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar2>(tensor2),
304*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar3>(tensor3));
305*da0073e9SAndroid Build Coastguard Worker   }
306*da0073e9SAndroid Build Coastguard Worker }
307*da0073e9SAndroid Build Coastguard Worker 
308*da0073e9SAndroid Build Coastguard Worker template <
309*da0073e9SAndroid Build Coastguard Worker     typename scalar1,
310*da0073e9SAndroid Build Coastguard Worker     typename scalar2,
311*da0073e9SAndroid Build Coastguard Worker     typename scalar3,
312*da0073e9SAndroid Build Coastguard Worker     typename scalar4,
313*da0073e9SAndroid Build Coastguard Worker     typename Op>
CPU_tensor_apply4(Tensor tensor1,Tensor tensor2,Tensor tensor3,Tensor tensor4,const Op op)314*da0073e9SAndroid Build Coastguard Worker inline void CPU_tensor_apply4(
315*da0073e9SAndroid Build Coastguard Worker     Tensor tensor1,
316*da0073e9SAndroid Build Coastguard Worker     Tensor tensor2,
317*da0073e9SAndroid Build Coastguard Worker     Tensor tensor3,
318*da0073e9SAndroid Build Coastguard Worker     Tensor tensor4,
319*da0073e9SAndroid Build Coastguard Worker     const Op op) {
320*da0073e9SAndroid Build Coastguard Worker   if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
321*da0073e9SAndroid Build Coastguard Worker     return;
322*da0073e9SAndroid Build Coastguard Worker   if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
323*da0073e9SAndroid Build Coastguard Worker     apply_op(
324*da0073e9SAndroid Build Coastguard Worker         tensor1.numel(),
325*da0073e9SAndroid Build Coastguard Worker         0,
326*da0073e9SAndroid Build Coastguard Worker         op,
327*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar1, 8>(tensor1),
328*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar2, 8>(tensor2),
329*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar3, 8>(tensor3),
330*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter_fixed<scalar4, 8>(tensor4));
331*da0073e9SAndroid Build Coastguard Worker   } else {
332*da0073e9SAndroid Build Coastguard Worker     apply_op(
333*da0073e9SAndroid Build Coastguard Worker         tensor1.numel(),
334*da0073e9SAndroid Build Coastguard Worker         0,
335*da0073e9SAndroid Build Coastguard Worker         op,
336*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar1>(tensor1),
337*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar2>(tensor2),
338*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar3>(tensor3),
339*da0073e9SAndroid Build Coastguard Worker         strided_tensor_iter<scalar4>(tensor4));
340*da0073e9SAndroid Build Coastguard Worker   }
341*da0073e9SAndroid Build Coastguard Worker }
342*da0073e9SAndroid Build Coastguard Worker 
343*da0073e9SAndroid Build Coastguard Worker } // namespace at
344