1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker namespace at {
11*da0073e9SAndroid Build Coastguard Worker
/*
 * The basic strategy for apply is as follows:
 *
 * 1. Starting with the outermost index, loop until we reach a dimension where
 * the data is no longer contiguous, i.e. the stride at that dimension is not
 * equal to the size of the tensor defined by the outer dimensions. Let's call
 * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
 * A is equal to the entire Tensor. Let's call the inner tensor B.
 *
 * 2. We loop through the indices in B, starting at its outermost dimension. For
 * example, if B is a 2x2 matrix, then we do:
 *
 * B[0][0]
 * B[0][1]
 * B[1][0]
 * B[1][1]
 *
 * We set the offset into the underlying storage as (storageOffset + stride_B *
 * index_B), i.e. basically we compute the offset into the storage as we would
 * normally for a Tensor. But because we are guaranteed the subsequent data is
 * contiguous in memory, we can simply loop for numel(A) iterations (the number
 * of elements in A) and perform the operation, without having to follow the
 * order described by the strides of A.
 *
 * 3. As an optimization, we merge dimensions of A that are contiguous in
 * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
 * then the first two dimensions can be merged for the purposes of APPLY,
 * reducing the number of nested loops. In the iterators below, this merging is
 * done by collapse_dims() when an iterator is constructed.
 */
41*da0073e9SAndroid Build Coastguard Worker
sort_strides(Tensor & tensor_)42*da0073e9SAndroid Build Coastguard Worker inline Tensor sort_strides(Tensor& tensor_) {
43*da0073e9SAndroid Build Coastguard Worker IntArrayRef strides = tensor_.strides();
44*da0073e9SAndroid Build Coastguard Worker std::vector<int64_t> indices;
45*da0073e9SAndroid Build Coastguard Worker indices.reserve(tensor_.ndimension());
46*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(tensor_.ndimension())) {
47*da0073e9SAndroid Build Coastguard Worker indices.push_back(i);
48*da0073e9SAndroid Build Coastguard Worker }
49*da0073e9SAndroid Build Coastguard Worker std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
50*da0073e9SAndroid Build Coastguard Worker return strides[i1] > strides[i2];
51*da0073e9SAndroid Build Coastguard Worker });
52*da0073e9SAndroid Build Coastguard Worker Tensor tensor = tensor_.permute(indices);
53*da0073e9SAndroid Build Coastguard Worker return tensor;
54*da0073e9SAndroid Build Coastguard Worker }
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker template <typename T, int N>
57*da0073e9SAndroid Build Coastguard Worker struct strided_tensor_iter_fixed {
58*da0073e9SAndroid Build Coastguard Worker public:
59*da0073e9SAndroid Build Coastguard Worker T* data_ = NULL;
60*da0073e9SAndroid Build Coastguard Worker int64_t dim_ = 0;
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker int64_t counter_[N] = {0};
63*da0073e9SAndroid Build Coastguard Worker int64_t sizes_[N] = {0};
64*da0073e9SAndroid Build Coastguard Worker int64_t strides_[N] = {0};
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
67*da0073e9SAndroid Build Coastguard Worker void operator=(strided_tensor_iter_fixed const& x) = delete;
68*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
69*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed(
70*da0073e9SAndroid Build Coastguard Worker Tensor& tensor,
71*da0073e9SAndroid Build Coastguard Worker C10_UNUSED bool sort_strides = false)
72*da0073e9SAndroid Build Coastguard Worker : data_(tensor.data_ptr<T>()) {
73*da0073e9SAndroid Build Coastguard Worker std::memset(counter_, 0, sizeof(int64_t) * N);
74*da0073e9SAndroid Build Coastguard Worker if (tensor.dim() > 0) {
75*da0073e9SAndroid Build Coastguard Worker std::memcpy(
76*da0073e9SAndroid Build Coastguard Worker sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
77*da0073e9SAndroid Build Coastguard Worker std::memcpy(
78*da0073e9SAndroid Build Coastguard Worker strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
79*da0073e9SAndroid Build Coastguard Worker }
80*da0073e9SAndroid Build Coastguard Worker dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
81*da0073e9SAndroid Build Coastguard Worker }
82*da0073e9SAndroid Build Coastguard Worker };
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker template <typename T>
85*da0073e9SAndroid Build Coastguard Worker struct strided_tensor_iter {
86*da0073e9SAndroid Build Coastguard Worker private:
87*da0073e9SAndroid Build Coastguard Worker public:
88*da0073e9SAndroid Build Coastguard Worker T* data_ = NULL;
89*da0073e9SAndroid Build Coastguard Worker int64_t dim_;
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker std::vector<int64_t> counter_;
92*da0073e9SAndroid Build Coastguard Worker std::vector<int64_t> sizes_;
93*da0073e9SAndroid Build Coastguard Worker std::vector<int64_t> strides_;
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter(strided_tensor_iter const&) = delete;
96*da0073e9SAndroid Build Coastguard Worker void operator=(strided_tensor_iter const& x) = delete;
97*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter(strided_tensor_iter&&) = default;
strided_tensor_iterstrided_tensor_iter98*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter(Tensor& tensor)
99*da0073e9SAndroid Build Coastguard Worker : data_(tensor.data_ptr<T>()),
100*da0073e9SAndroid Build Coastguard Worker dim_(tensor.ndimension()),
101*da0073e9SAndroid Build Coastguard Worker counter_(dim_, 0),
102*da0073e9SAndroid Build Coastguard Worker sizes_(tensor.sizes().vec()),
103*da0073e9SAndroid Build Coastguard Worker strides_(tensor.strides().vec()) {
104*da0073e9SAndroid Build Coastguard Worker dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
105*da0073e9SAndroid Build Coastguard Worker }
106*da0073e9SAndroid Build Coastguard Worker };
107*da0073e9SAndroid Build Coastguard Worker
_all_equal_numel(at::ArrayRef<Tensor> tensors)108*da0073e9SAndroid Build Coastguard Worker inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
109*da0073e9SAndroid Build Coastguard Worker if (tensors.empty())
110*da0073e9SAndroid Build Coastguard Worker return true;
111*da0073e9SAndroid Build Coastguard Worker int64_t all_numel = tensors[0].numel();
112*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(1, tensors.size())) {
113*da0073e9SAndroid Build Coastguard Worker if (tensors[i].numel() != all_numel)
114*da0073e9SAndroid Build Coastguard Worker return false;
115*da0073e9SAndroid Build Coastguard Worker }
116*da0073e9SAndroid Build Coastguard Worker return true;
117*da0073e9SAndroid Build Coastguard Worker }
118*da0073e9SAndroid Build Coastguard Worker
_all_equal_numel_error(at::ArrayRef<Tensor> tensors)119*da0073e9SAndroid Build Coastguard Worker inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
120*da0073e9SAndroid Build Coastguard Worker std::ostringstream oss;
121*da0073e9SAndroid Build Coastguard Worker oss << "inconsistent tensor size, expected ";
122*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < tensors.size() - 1; i++) {
123*da0073e9SAndroid Build Coastguard Worker oss << tensors[i].sizes() << ", ";
124*da0073e9SAndroid Build Coastguard Worker }
125*da0073e9SAndroid Build Coastguard Worker oss << "and " << tensors[tensors.size() - 1].sizes()
126*da0073e9SAndroid Build Coastguard Worker << " to have the same number of elements, but got ";
127*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < tensors.size() - 1; i++) {
128*da0073e9SAndroid Build Coastguard Worker oss << tensors[i].numel() << ", ";
129*da0073e9SAndroid Build Coastguard Worker }
130*da0073e9SAndroid Build Coastguard Worker oss << "and " << tensors[tensors.size() - 1].numel()
131*da0073e9SAndroid Build Coastguard Worker << " elements respectively";
132*da0073e9SAndroid Build Coastguard Worker return oss.str();
133*da0073e9SAndroid Build Coastguard Worker }
134*da0073e9SAndroid Build Coastguard Worker
_apply_preamble(ArrayRef<Tensor> tensors)135*da0073e9SAndroid Build Coastguard Worker inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
136*da0073e9SAndroid Build Coastguard Worker checkDeviceType("CPU_tensor_apply", tensors, kCPU);
137*da0073e9SAndroid Build Coastguard Worker checkLayout("CPU_tensor_apply", tensors, kStrided);
138*da0073e9SAndroid Build Coastguard Worker if (!_all_equal_numel(tensors))
139*da0073e9SAndroid Build Coastguard Worker AT_ERROR(_all_equal_numel_error(tensors));
140*da0073e9SAndroid Build Coastguard Worker // An empty tensor has no elements
141*da0073e9SAndroid Build Coastguard Worker for (auto& t : tensors)
142*da0073e9SAndroid Build Coastguard Worker if (t.numel() == 0)
143*da0073e9SAndroid Build Coastguard Worker return false;
144*da0073e9SAndroid Build Coastguard Worker return true;
145*da0073e9SAndroid Build Coastguard Worker }
146*da0073e9SAndroid Build Coastguard Worker
_max_dim_tensors(ArrayRef<Tensor> tensors)147*da0073e9SAndroid Build Coastguard Worker inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
148*da0073e9SAndroid Build Coastguard Worker int64_t dim = 0;
149*da0073e9SAndroid Build Coastguard Worker for (auto& t : tensors)
150*da0073e9SAndroid Build Coastguard Worker dim = std::max(dim, t.ndimension());
151*da0073e9SAndroid Build Coastguard Worker return dim;
152*da0073e9SAndroid Build Coastguard Worker }
153*da0073e9SAndroid Build Coastguard Worker
// Ends the recursion of iterate(size, iters...): no iterators left to advance.
inline void iterate(int64_t /*size*/) {}

/// Advances every iterator by `size` steps along its last collapsed dimension
/// (index dim_ - 1), updating both the counter and the data pointer.
/// No carry into other dimensions happens here; see iterate_overflow.
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  iter.counter_[inner] += size;
  iter.data_ += size * iter.strides_[inner];
  iterate(size, iter_tail...);
}
162*da0073e9SAndroid Build Coastguard Worker
// Ends the recursion of iterate_continue(iters...): with no iterators left,
// nothing stops the current run.
inline bool iterate_continue() {
  return true;
}

/// Returns true while every iterator's counter on its last collapsed
/// dimension (index dim_ - 1) is still below that dimension's size.
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  if (!(iter.counter_[inner] < iter.sizes_[inner])) {
    return false;
  }
  return iterate_continue(iter_tail...);
}
172*da0073e9SAndroid Build Coastguard Worker
// Ends the recursion of max_iterate_size(iters...): no iterator imposes a
// limit, so the identity for min is the largest int64_t.
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

/// Returns the smallest number of steps left on the last collapsed dimension
/// (sizes_[dim_ - 1] - counter_[dim_ - 1]) across all iterators.
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  const int64_t remaining = iter.sizes_[inner] - iter.counter_[inner];
  return std::min(remaining, max_iterate_size(iter_tail...));
}
183*da0073e9SAndroid Build Coastguard Worker
// Ends the recursion of iterate_overflow(iters...).
inline void iterate_overflow() {}

/// For each iterator whose last collapsed dimension is exhausted
/// (counter == size), carries outward: every exhausted dimension d > 0 is
/// reset to 0, dimension d - 1 is incremented, and data_ moves back by
/// sizes_[d] * strides_[d] and forward by strides_[d - 1].
/// Dimension 0 is never reset.
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  if (iter.counter_[inner] == iter.sizes_[inner]) {
    for (int64_t d = inner; d >= 1; --d) {
      if (iter.counter_[d] != iter.sizes_[d]) {
        continue;
      }
      iter.counter_[d] = 0;
      ++iter.counter_[d - 1];
      iter.data_ += iter.strides_[d - 1] - iter.sizes_[d] * iter.strides_[d];
    }
  }
  iterate_overflow(iter_tail...);
}
200*da0073e9SAndroid Build Coastguard Worker
// Ends the recursion of forward(offset, iters...).
inline void forward(int64_t /*offset*/) {}

/// Moves every iterator `offset` elements forward by decomposing the linear
/// offset into per-dimension steps, last collapsed dimension fastest.
/// Counters are incremented without carrying, so the result is only
/// consistent when the counters start at zero (as they do for the
/// freshly constructed iterators apply_op passes here).
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t remaining = offset;
  for (int64_t d = iter.dim_ - 1; d >= 0; --d) {
    const int64_t size_d = iter.sizes_[d];
    const int64_t step = remaining % size_d;
    remaining /= size_d;
    iter.data_ += step * iter.strides_[d];
    iter.counter_[d] += step;
  }
  forward(offset, iter_tail...);
}
214*da0073e9SAndroid Build Coastguard Worker
// Ends the recursion of max_dim(iters...): no iterators means 0 dimensions.
inline int64_t max_dim() {
  return 0;
}

/// Returns the largest dim_ among the given iterators (0 if there are none).
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  const int64_t rest = max_dim(iter_tail...);
  return iter.dim_ > rest ? iter.dim_ : rest;
}
223*da0073e9SAndroid Build Coastguard Worker
// Base case of the apply_op overload set. Nothing in this file calls it
// (apply_op is always given numel, offset and op); it is kept so the overload
// set stays unchanged for any external caller.
inline void apply_op() {}

/// Invokes `op(*iters.data_...)` for `numel` consecutive elements of the
/// iteration order, starting `offset` elements past the iterators' position.
///
/// @param numel  number of elements to visit.
/// @param offset linear index of the first element to visit; applied with
///               forward(), which expects counters starting at zero.
/// @param op     callable receiving one dereferenced element per iterator.
/// @param iters  iterator states, taken by value so the caller's copies are
///               not advanced.
template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  if (offset > 0)
    forward(offset, iters...);
  // A single element needs no stepping. Handling it here covers 0-dim
  // iterators (dim_ == 0), for which the stepping helpers below would index
  // counter_[dim_ - 1] == counter_[-1]. Requiring only numel == 1 (rather
  // than every iterator being 0-dim) keeps that safe when a 0-dim tensor is
  // combined with a one-element tensor of higher rank.
  if (numel == 1) {
    op(*iters.data_...);
    return;
  }
  // Splitting this into chunks helps the compiler create faster assembly:
  // the inner loop only steps along the last collapsed dimension, and
  // iterate_overflow carries into the other dimensions between runs.
  for (int64_t i = 0; i < numel;) {
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}
249*da0073e9SAndroid Build Coastguard Worker
/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes one
  reference per given tensor, in the same order as the tensors, each bound to
  the current element (apply_op calls op(*iters.data_...)). For example, to
  compute a = b * c, op would be of the form:
  [](scalar1& a_val, const scalar2& b_val, const scalar3& c_val) {
    a_val = b_val * c_val;
  };
*/
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker template <typename scalar1, typename scalar2, typename Op>
CPU_tensor_apply2(Tensor tensor1,Tensor tensor2,const Op op)261*da0073e9SAndroid Build Coastguard Worker inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
262*da0073e9SAndroid Build Coastguard Worker if (!_apply_preamble({tensor1, tensor2}))
263*da0073e9SAndroid Build Coastguard Worker return;
264*da0073e9SAndroid Build Coastguard Worker if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
265*da0073e9SAndroid Build Coastguard Worker apply_op(
266*da0073e9SAndroid Build Coastguard Worker tensor1.numel(),
267*da0073e9SAndroid Build Coastguard Worker 0,
268*da0073e9SAndroid Build Coastguard Worker op,
269*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar1, 8>(tensor1),
270*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar2, 8>(tensor2));
271*da0073e9SAndroid Build Coastguard Worker } else {
272*da0073e9SAndroid Build Coastguard Worker apply_op(
273*da0073e9SAndroid Build Coastguard Worker tensor1.numel(),
274*da0073e9SAndroid Build Coastguard Worker 0,
275*da0073e9SAndroid Build Coastguard Worker op,
276*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar1>(tensor1),
277*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar2>(tensor2));
278*da0073e9SAndroid Build Coastguard Worker }
279*da0073e9SAndroid Build Coastguard Worker }
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker template <typename scalar1, typename scalar2, typename scalar3, typename Op>
CPU_tensor_apply3(Tensor tensor1,Tensor tensor2,Tensor tensor3,const Op op)282*da0073e9SAndroid Build Coastguard Worker inline void CPU_tensor_apply3(
283*da0073e9SAndroid Build Coastguard Worker Tensor tensor1,
284*da0073e9SAndroid Build Coastguard Worker Tensor tensor2,
285*da0073e9SAndroid Build Coastguard Worker Tensor tensor3,
286*da0073e9SAndroid Build Coastguard Worker const Op op) {
287*da0073e9SAndroid Build Coastguard Worker if (!_apply_preamble({tensor1, tensor2, tensor3}))
288*da0073e9SAndroid Build Coastguard Worker return;
289*da0073e9SAndroid Build Coastguard Worker if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
290*da0073e9SAndroid Build Coastguard Worker apply_op(
291*da0073e9SAndroid Build Coastguard Worker tensor1.numel(),
292*da0073e9SAndroid Build Coastguard Worker 0,
293*da0073e9SAndroid Build Coastguard Worker op,
294*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar1, 8>(tensor1),
295*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar2, 8>(tensor2),
296*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar3, 8>(tensor3));
297*da0073e9SAndroid Build Coastguard Worker } else {
298*da0073e9SAndroid Build Coastguard Worker apply_op(
299*da0073e9SAndroid Build Coastguard Worker tensor1.numel(),
300*da0073e9SAndroid Build Coastguard Worker 0,
301*da0073e9SAndroid Build Coastguard Worker op,
302*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar1>(tensor1),
303*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar2>(tensor2),
304*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar3>(tensor3));
305*da0073e9SAndroid Build Coastguard Worker }
306*da0073e9SAndroid Build Coastguard Worker }
307*da0073e9SAndroid Build Coastguard Worker
308*da0073e9SAndroid Build Coastguard Worker template <
309*da0073e9SAndroid Build Coastguard Worker typename scalar1,
310*da0073e9SAndroid Build Coastguard Worker typename scalar2,
311*da0073e9SAndroid Build Coastguard Worker typename scalar3,
312*da0073e9SAndroid Build Coastguard Worker typename scalar4,
313*da0073e9SAndroid Build Coastguard Worker typename Op>
CPU_tensor_apply4(Tensor tensor1,Tensor tensor2,Tensor tensor3,Tensor tensor4,const Op op)314*da0073e9SAndroid Build Coastguard Worker inline void CPU_tensor_apply4(
315*da0073e9SAndroid Build Coastguard Worker Tensor tensor1,
316*da0073e9SAndroid Build Coastguard Worker Tensor tensor2,
317*da0073e9SAndroid Build Coastguard Worker Tensor tensor3,
318*da0073e9SAndroid Build Coastguard Worker Tensor tensor4,
319*da0073e9SAndroid Build Coastguard Worker const Op op) {
320*da0073e9SAndroid Build Coastguard Worker if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
321*da0073e9SAndroid Build Coastguard Worker return;
322*da0073e9SAndroid Build Coastguard Worker if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
323*da0073e9SAndroid Build Coastguard Worker apply_op(
324*da0073e9SAndroid Build Coastguard Worker tensor1.numel(),
325*da0073e9SAndroid Build Coastguard Worker 0,
326*da0073e9SAndroid Build Coastguard Worker op,
327*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar1, 8>(tensor1),
328*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar2, 8>(tensor2),
329*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar3, 8>(tensor3),
330*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter_fixed<scalar4, 8>(tensor4));
331*da0073e9SAndroid Build Coastguard Worker } else {
332*da0073e9SAndroid Build Coastguard Worker apply_op(
333*da0073e9SAndroid Build Coastguard Worker tensor1.numel(),
334*da0073e9SAndroid Build Coastguard Worker 0,
335*da0073e9SAndroid Build Coastguard Worker op,
336*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar1>(tensor1),
337*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar2>(tensor2),
338*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar3>(tensor3),
339*da0073e9SAndroid Build Coastguard Worker strided_tensor_iter<scalar4>(tensor4));
340*da0073e9SAndroid Build Coastguard Worker }
341*da0073e9SAndroid Build Coastguard Worker }
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker } // namespace at
344