// aten/src/ATen/native/cpu/IndexKernel.cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/IndexKernel.h>

#include <cmath>
#include <iostream>

#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/AtomicAddFloat.h>
#include <ATen/native/cpu/IndexKernelUtils.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#include <c10/core/Scalar.h>

namespace at::native {
namespace {

using namespace vec;

void index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
    iter.dtype(), "index_cpu", [&] {
    cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
      *(scalar_t*)dst = *(scalar_t*)(src + offset);
    });
  });
}
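
// Note on the cpu_index_kernel callbacks used in this file: `offset` is the byte
// offset computed from the integer index tensors, and it is applied to the
// operand being indexed (`src` for the gather-style `index` above, `dst` for the
// scatter-style `index_put` below); the other pointer is advanced by the
// TensorIterator as usual.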

// Given a linear index, returns the corresponding offset into the tensor's storage.
// Implements the same algorithm as its (legacy) GPU counterpart cuda::detail::IndexToOffset.
// OffsetCalculator implements the same algorithm yet again, but iterating in column-major order.
struct IndexToOffset {
  const IntArrayRef sizes;
  const IntArrayRef strides;
  const int64_t ndim;
  explicit IndexToOffset(const TensorBase & tensor) :
      sizes(tensor.sizes()), strides(tensor.strides()), ndim(tensor.dim()) {
  }

  int64_t get(int64_t linear_index) const {
    int64_t offset = 0;
    for (int64_t i = ndim - 1; i > 0; i--) {
      offset += (linear_index % sizes[i]) * strides[i];
      linear_index /= sizes[i];
    }
    return offset + linear_index * strides[0];
  }
};
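
// For example (illustrative values, not taken from the surrounding code): for a
// tensor with sizes [2, 3, 4] and strides [12, 4, 1] (contiguous), linear index 7
// is decomposed from the innermost dimension outwards:
//   i = 2: offset += (7 % 4) * 1 = 3,  linear_index = 7 / 4 = 1
//   i = 1: offset += (1 % 3) * 4 = 4,  linear_index = 1 / 3 = 0
//   result: 3 + 4 + 0 * 12 = 7
// i.e. for a contiguous tensor the offset equals the linear index, while for
// non-contiguous strides the same decomposition maps into the actual storage.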

template <typename scalar_t, typename func_t>
void cpu_take_put_kernel(
    TensorIterator& iter,
    const TensorBase& indexed,
    bool is_indexed_data_mutated,
    const func_t& f,
    bool serial_execution=false) {
  // This kernel follows the same strategy as `cpu_index_kernel`.
  // Even though the indexed tensor is const, we modify it through its data_ptr.
  // This is a bit dirty, but the alternative of adding a zero-strided copy of the
  // tensor to `iter` would not be much better.

  // When launching the parallel version, use a grain size smaller than
  // internal::GRAIN_SIZE so that the available threads get a more balanced
  // workload and better cache locality.
  // The grain size here was chosen via op benchmarks to amortize the thread launch overhead.
  // Perhaps tweak this number for `put_`? This number was tuned for `index_put`.
  constexpr int parallel_grain_size = 3000;
  const bool is_contiguous = indexed.is_contiguous();
  const auto numel = indexed.numel();
  const auto offset_indexed = IndexToOffset(indexed);

  auto* indexed_data = is_indexed_data_mutated ?
   indexed.data_ptr<scalar_t>()
   : const_cast<scalar_t*>(indexed.const_data_ptr<scalar_t>());
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* iterated_data_bytes = data[0];
    auto* index_data_bytes = data[1];
    for (const auto elem C10_UNUSED : c10::irange(n)) {
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      auto& iterated = *reinterpret_cast<scalar_t*>(iterated_data_bytes);

      TORCH_CHECK_INDEX(idx >= -numel && idx < numel,
                        "out of range: tried to access index ",
                        idx, " on a tensor of ", numel, " elements.");
      if (idx < 0) {
        idx += numel;
      }
      if (!is_contiguous) {
        idx = offset_indexed.get(idx);
      }
      f(iterated, indexed_data, idx);
      iterated_data_bytes += strides[0];
      index_data_bytes += strides[1];
    }
  };
  if (serial_execution) {
    iter.serial_for_each(loop, {0, iter.numel()});
  } else {
    iter.for_each(loop, parallel_grain_size);
  }
}
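
// As used by the callers below, data[0] of the loop above is the tensor being
// iterated (the values for `put_`, the output for `take`) and data[1] is an
// int64 tensor of linear indices into `indexed`.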

void put_kernel(
  TensorIterator& iter,
  const TensorBase & self,
  const bool accumulate) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
    iter.dtype(), "take_put_cpu", [&] {
  // iter could be const, but for_each does not have a const version
    if (accumulate) {
      // N.B. the determinism issue here is the same as that of `index_put_kernel`.
      // See Note [Enabling Deterministic Operations]
      // Parallel cpu_put_kernel with accumulation is nondeterministic, so we
      // must enable serial execution if deterministic algorithms are enabled.
      bool is_deterministic = at::globalContext().deterministicAlgorithms();
      bool use_parallel_for = (!is_deterministic) && (
        (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
      if (use_parallel_for && iter.dtype() == ScalarType::Float) {
        cpu_take_put_kernel<float>(iter, self, true,
            [](float& iterated, float* indexed, const int64_t idx) {
                cpu_atomic_add_float(indexed+idx, iterated);
              });
      } else {
        // TODO: investigate parallelization of the accumulate kernel.
        // Unlike the non-accumulate case, this needs to be thread-safe.
        cpu_take_put_kernel<scalar_t>(iter, self, true,
            [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
                indexed[idx] += iterated;
              },
            /*serial_execution=*/true);
      }
    } else {
      cpu_take_put_kernel<scalar_t>(iter, self, true,
          [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
              indexed[idx] = iterated;
            });
    }
  });
}

void take_kernel(
  TensorIterator& iter,
  const TensorBase & input) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
    iter.dtype(), "take_cpu", [&] {
      cpu_take_put_kernel<scalar_t>(iter, input, false,
          [](scalar_t& iterated, const scalar_t* indexed, const int64_t idx) {
              iterated = indexed[idx];
            });
    });
}

void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
  // NOTE: duplicate indices are only supported if accumulate is true.
  AT_DISPATCH_V2(
    iter.dtype(),
    "index_put",
    AT_WRAP([&] {
      // See Note [Enabling Deterministic Operations]
      // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
      // must enable serial execution if deterministic algorithms are enabled.
      const bool is_deterministic = at::globalContext().deterministicAlgorithms();
      if (accumulate) {
        bool use_parallel_for = (!is_deterministic) && (
          (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
        if (use_parallel_for && iter.dtype() == ScalarType::Float) {
          cpu_index_kernel<float>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
            cpu_atomic_add_float((float*)(dst + offset), *(float*)src);
          });
        } else {
          // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
          // this needs to be thread-safe.
          cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
            *(scalar_t*)(dst + offset) += *(scalar_t*)src;
          }, /*serial_execution=*/true);
        }
      } else {
        cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset) = *(scalar_t*)src;
        }, /*serial_execution=*/is_deterministic);
      }
    }),
    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
    AT_EXPAND(AT_FLOAT8_TYPES),
    kComplexHalf,
    kHalf,
    kBool,
    kBFloat16);
}

void index_fill_kernel(
  TensorIterator& iter,
  int64_t dim,
  int64_t self_dim_size,
  int64_t self_dim_stride,
  const Scalar& source) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_fill_cpu", [&] {
    auto fill_val = source.to<scalar_t>();
    auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      for (const auto elem C10_UNUSED : c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
        auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
        TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
                          "index ", idx, " is out of bounds for dimension ",
                          dim, " with size ", self_dim_size);
        if (idx < 0) {
          idx += self_dim_size;
        }

        self_data[idx * self_dim_stride] = fill_val;

        self_data_bytes += strides[0];
        index_data_bytes += strides[1];
      }
    };
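    // When the index operand is broadcast (stride 0), the same index is used for
    // every element, so it is validated once and the bounds check is hoisted out
    // of the inner loop.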
    auto handle_zero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
                        "index ", idx, " is out of bounds for dimension ",
                        dim, " with size ", self_dim_size);
      if (idx < 0) {
        idx += self_dim_size;
      }
      for (const auto elem C10_UNUSED: c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);

        self_data[idx * self_dim_stride] = fill_val;

        self_data_bytes += strides[0];
      }
    };

    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      auto idx_stride = strides[1];
      if (idx_stride) {
        handle_nonzero_idx_stride(data, strides, n);
      }
      else {
        handle_zero_idx_stride(data, strides, n);
      }
    };
    iter.for_each(loop);
  });
}

void index_copy_kernel(
  TensorIterator& iter,
  int64_t dim,
  int64_t self_dim_size,
  int64_t self_dim_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_copy_cpu", [&] {
    auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      auto* source_data_bytes = data[2];
      for (const auto elem C10_UNUSED : c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
        auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
        auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);
        TORCH_CHECK_INDEX(idx >= 0 && idx < self_dim_size,
              "index_copy_(): index ", idx, " is out of bounds for dimension ",
              dim, " with size ", self_dim_size);

        self_data[idx * self_dim_stride] = *source_data;

        self_data_bytes += strides[0];
        index_data_bytes += strides[1];
        source_data_bytes += strides[2];
      }
    };
    auto handle_zero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      auto* source_data_bytes = data[2];
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      TORCH_CHECK_INDEX(idx >= 0 && idx < self_dim_size,
            "index_copy_(): index ", idx, " is out of bounds for dimension ",
            dim, " with size ", self_dim_size);
      for (const auto elem C10_UNUSED : c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
        auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);

        self_data[idx * self_dim_stride] = *source_data;

        self_data_bytes += strides[0];
        source_data_bytes += strides[2];
      }
    };

    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      auto idx_stride = strides[1];
      if (idx_stride) {
        handle_nonzero_idx_stride(data, strides, n);
      }
      else {
        handle_zero_idx_stride(data, strides, n);
      }
    };
    bool is_deterministic = at::globalContext().deterministicAlgorithms();
    if (is_deterministic) {
      iter.serial_for_each(loop, {0, iter.numel()});
    } else {
      iter.for_each(loop);
    }
  });
}

template <typename scalar_t>
void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) {
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    char* mask = data[1];
    for (const auto i : c10::irange(n)) {
      bool mask_value = *reinterpret_cast<bool*>(mask + strides[1] * i);

      if (mask_value) {
        *(scalar_t*)(dst + strides[0] * i) = value;
      }
    }
  };
  iter.for_each(loop);
}

void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf,
    iter.dtype(), "masked_fill", [&] {
      scalar_t scalar_val = value.to<scalar_t>();
      auto mask_dtype = iter.input_dtype(0);
      TORCH_CHECK(mask_dtype == ScalarType::Bool, "masked_fill only supports boolean masks, "
        "but got mask with dtype ", mask_dtype);
      cpu_masked_fill_kernel<scalar_t>(iter, scalar_val);
    });
}

template <typename scalar_t>
void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
  std::ptrdiff_t source_cntr = 0;
  const scalar_t* source_ptr = source.const_data_ptr<scalar_t>();
  auto numel = source.numel();

  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    const int64_t dst_stride = strides[0];
    char* mask = data[1];
    const int64_t mask_stride = strides[1];
    for (const auto i : c10::irange(n)) {
      auto mask_value = *reinterpret_cast<bool*>(mask + mask_stride * i);
      if (mask_value) {
        TORCH_CHECK(source_cntr < numel, "Number of elements of source < number of ones in mask");
        *(scalar_t*)(dst + dst_stride * i) = *(source_ptr);
        source_ptr++;
        source_cntr++;
      }
    }
  };
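  // source_cntr and source_ptr carry state across elements, so this kernel must
  // run serially; a parallel for_each would interleave chunks across threads.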
  iter.serial_for_each(loop, {0, iter.numel()});
}

void masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
  TORCH_CHECK(iter.input_dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
    "but got mask with dtype ", iter.input_dtype());
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool,
      ScalarType::BFloat16,
      ScalarType::Half,
      iter.dtype(),
      "masked_scatter",
      [&] {
          cpu_masked_scatter_kernel<scalar_t>(iter, source);
      });
}

template <typename scalar_t, typename mask_t, typename func_t>
void cpu_masked_select_serial_kernel(TensorIterator& iter, const func_t& f) {
  int64_t offset = 0;
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    char* src = data[1];
    char* mask = data[2];
    for (const auto i : c10::irange(n)) {
      mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
      if constexpr (!std::is_same<mask_t, bool>::value) {
        TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
      }
      if (mask_value) {
        int64_t offset_bytes = offset * sizeof(scalar_t);
        f(dst, src + strides[1] * i, offset_bytes);
        offset++;
      }
    }
  };
  iter.serial_for_each(loop, {0, iter.numel()});
}

void masked_select_serial_kernel(TensorIterator& iter, int64_t result_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
    iter.dtype(), "masked_select", [&] {
      auto mask_dtype = iter.input_dtype(1);
      if (mask_dtype == ScalarType::Bool) {
        cpu_masked_select_serial_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      } else {
        cpu_masked_select_serial_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      }
    });
}

template <typename scalar_t, typename mask_t, typename func_t>
void cpu_masked_select_kernel(TensorIterator& iter, const func_t& f) {
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    char* src = data[1];
    char* mask = data[2];
    char* mask_prefix_sum = data[3];
    for (const auto i : c10::irange(n)) {
      mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
      if constexpr (!std::is_same<mask_t, bool>::value) {
        TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
      }
      if (mask_value) {
        int64_t offset = *(int64_t*)(mask_prefix_sum + strides[3] * i);
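        // mask_prefix_sum holds a running (inclusive) count of selected elements,
        // hence the "- 1" below to get the zero-based position in the output.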
        int64_t offset_bytes = (offset - 1) * sizeof(scalar_t);
        f(dst, src + strides[1] * i, offset_bytes);
      }
    }
  };
  iter.for_each(loop);
}

void masked_select_kernel(TensorIterator& iter, int64_t result_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
    iter.dtype(), "masked_select", [&] {
      auto mask_dtype = iter.input_dtype(1);
      if (mask_dtype == ScalarType::Bool) {
        cpu_masked_select_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      } else {
        cpu_masked_select_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      }
    });
}

template <typename scalar_t>
void cpu_hflip_vec(at::TensorIterator& iter) {

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // Here ntensors covers the output and one input. The tensor iterator actually
    // defines output, input and restrided_input (see
    // aten/src/ATen/native/TensorTransformations.cpp#L64-L66), but only the output
    // and input are used here.
    static constexpr int ntensors = 2;
    const int64_t *outer_strides = &strides[3];

    std::array<char*, ntensors> data_arr;
    std::copy_n(base, ntensors, data_arr.data());

    using Vec = Vectorized<scalar_t>;

    constexpr auto stride = sizeof(scalar_t);
    TORCH_INTERNAL_ASSERT(stride == -strides[0] && stride == strides[1]);

    for (const auto j C10_UNUSED : c10::irange(size1)) {

      // vectorized loop with negative stride for output
      char** C10_RESTRICT data_ = data_arr.data();
      int64_t n = size0;

      char* C10_RESTRICT data[ntensors];
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] = data_[arg];
      }

      int64_t i = 0;

      // data[0] unaligned pre-pass
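      // (a scalar pre-pass over a few leading elements, presumably so that the
      // vectorized stores into data[0] below start closer to a 32-byte boundary)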
      int64_t offset = (j * n + (n - i - Vec::size())) % 32;
      offset = (offset >= n) ? n : offset;
      for (; i < offset; i++) {
        scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
        *out_ptr = *(scalar_t *)(data[1] + i * stride);
      }
      // Empirically, it is faster to process 3 data items at a time than 2 or 4
      for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) {
        auto out1 = Vec::loadu(data[1] + i * stride);
        auto out2 = Vec::loadu(data[1] + (i + Vec::size()) * stride);
        auto out3 = Vec::loadu(data[1] + (i + 2 * Vec::size()) * stride);
        // flip the vector: 1234 -> 4321
        out1 = flip(out1);
        out2 = flip(out2);
        out3 = flip(out3);
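        // After the flip, lane 0 of out1 holds input element i + Vec::size() - 1,
        // which belongs at output offset -(i + Vec::size() - 1); out2 and out3
        // follow one Vec::size() further back each.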
        out1.store(data[0] - (i + Vec::size() - 1) * stride);
        out2.store(data[0] - (i + 2 * Vec::size() - 1) * stride);
        out3.store(data[0] - (i + 3 * Vec::size() - 1) * stride);
      }
      if (i < n) {
        for (; i < n; i++) {
          scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
          *out_ptr = *(scalar_t *)(data[1] + i * stride);
        }
      }

      // advance:
      for (const auto arg : c10::irange(ntensors)) {
        data_arr[arg] += outer_strides[arg];
      }
    }
  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

void cpu_vflip_memcpy(at::TensorIterator& iter) {
  // This is a vertical flip specialization that uses memcpy to speed up the runtime

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // Here ntensors covers the output and one input. The tensor iterator actually
    // defines output, input and restrided_input (see
    // aten/src/ATen/native/TensorTransformations.cpp#L64-L66), but only the output
    // and input are used here.
    static constexpr int ntensors = 2;
    const int64_t *outer_strides = &strides[3];

    std::array<char*, ntensors> data_arr;
    std::copy_n(base, ntensors, data_arr.data());

    TORCH_INTERNAL_ASSERT(strides[0] == strides[1]);
    const int64_t stride = strides[0];

    for (const auto j C10_UNUSED : c10::irange(size1)) {

      char** C10_RESTRICT data_ = data_arr.data();
      int64_t n = size0;

      char* C10_RESTRICT data[ntensors];
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] = data_[arg];
      }

      memcpy(data[0], data[1], n * stride);

      // advance:
      for (const auto arg : c10::irange(data_arr.size())) {
        data_arr[arg] += outer_strides[arg];
      }
    }
  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

constexpr int64_t hflip_mask_size = 32;

std::array<char, hflip_mask_size> generate_vec_hflip_reg_mask(int64_t data_stride) {
    std::array<char, hflip_mask_size> mask;
    for (const auto k : c10::irange(hflip_mask_size / 2)) {
      int j = k / data_stride + 1;
      int v = (j * data_stride - 1) - (k % data_stride);
      v = std::min(v, (int) (hflip_mask_size / 2 - 1));
      mask[hflip_mask_size - 1 - k] = v;
      mask[hflip_mask_size / 2 - 1 - k] = v;
    }
    return mask;
}
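
// The mask above is a per-128-bit-lane byte shuffle for _mm256_shuffle_epi8 that
// reverses the order of `data_stride`-byte groups while keeping each group's
// internal order. For example (worked out by hand, not taken from the code),
// data_stride = 3 gives { 15, 12, 13, 14, 9, 10, 11, 6, 7, 8, 3, 4, 5, 0, 1, 2 }
// in both halves: lane byte 15 (part of a group split across the lane boundary)
// moves to position 0 and each complete 3-byte group is mirrored.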

int64_t vectorized_cpu_hflip_channels_last(
    char * C10_RESTRICT *data, const int64_t data_size, const int64_t data_stride, const std::array<char, 32> & mdata) {

  int64_t i = 0;
#ifdef CPU_CAPABILITY_AVX2

  constexpr auto vec_size = 256 / 8;

  if (data_size > vec_size) {

      // Example for num channels=3 and dtype=uint8
      // -> data_stride = 3
      // -> usable_vec_stride = 30
      // -> usable_vec_half_stride = 15
      // Data: (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 32 33)
      // load in 2 parts
      // R = [ (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 | (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 ]
      // flip(R) ->
      // R = [ 31 (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) | 16 (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3) ]
      //
      // Write in 2 parts
      // Output pointer: output_ptr = data[0]                                                                                  v
      // - Init:
      //                (X X X)  (X X X)    (X X X)    (X X X)    (X X X)    (X X X)    (X X X)    (X X X)    (X X X) (X X X) (X X X)
      // 0) Move to initial position: output_ptr = data[0] + data_stride - vec_size / 2;
      //                                                                          v
      //                (X X X)  (X X X)    (X X X)    (X X X)    (X X X)    (X X X)    (X X X)    (X X X)    (X X X) (X X X) (X X X)
      // - In the loop:
      // 1) Write 1st block from output_ptr
      //                                                                            v
      //                                                                            |----> vec_size / 2 ---------------------------|
      // Output part 1: (X X X)  (X X X)    (X X X)    (X X X)    (X X X)     (X X 16)  (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)
      // 2) Write 2nd block from output_ptr - usable_vec_half_stride:
      //                                                                            v
      //                     |-----> vec_size / 2 ----------------------------------|
      // Output part 2: (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)
      //
      // 3) Move to the next position: output_ptr -= usable_vec_stride
      //
      // - After the loop:
      // 4) Move to write position
      //                 v
      //                (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)

    const __m256i mask = _mm256_loadu_si256((__m256i *) mdata.data());

    const auto usable_vec_stride = 2 * (vec_size / 2 / data_stride) * data_stride;
    const auto usable_vec_half_stride = usable_vec_stride / 2;

    auto output_ptr = data[0] + data_stride - vec_size / 2;
    auto input_ptr = data[1];

    for (; i < data_size - vec_size; i += usable_vec_stride) {

      // load 256 bits as two 128-bit parts
      auto a0 = _mm_loadu_si128((__m128i *) (input_ptr + i));
      auto b0 = _mm256_castsi128_si256(a0);
      auto a1 = _mm_loadu_si128((__m128i *) (input_ptr + i + usable_vec_half_stride));
      auto data_vec = _mm256_inserti128_si256(b0, a1, 1);

      auto reversed_vec = _mm256_shuffle_epi8(data_vec, mask);

      // write output in two parts
      auto rev_vec_h = _mm256_extracti128_si256(reversed_vec, 0);
      _mm_storeu_si128((__m128i *) (output_ptr - i), rev_vec_h);
      auto rev_vec_l = _mm256_extracti128_si256(reversed_vec, 1);
      _mm_storeu_si128((__m128i *) (output_ptr - i - usable_vec_half_stride), rev_vec_l);
    }

    data[0] -= i;
    data[1] += i;
  }
#endif
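  // Returns the number of bytes handled by the vectorized path (0 if AVX2 is not
  // available); data[0] has been moved back and data[1] forward by that amount.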
  return i;
}

void cpu_hflip_channels_last_vec(at::TensorIterator& iter) {

  auto input_strides = iter.strides(1);
  const auto data_stride = input_strides[1];

  // Generate the AVX shuffle mask once
  alignas(hflip_mask_size) auto mdata = generate_vec_hflip_reg_mask(data_stride);

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // Here ntensors covers the output and one input. The tensor iterator actually
    // defines output, input and restrided_input (see
    // aten/src/ATen/native/TensorTransformations.cpp#L64-L66), but only the output
    // and input are used here.
    static constexpr int ntensors = 2;
    const int64_t *outer_strides = &strides[3];
    const int64_t stride = strides[0];

    TORCH_INTERNAL_ASSERT(stride == strides[1]);

    auto c = -outer_strides[0];
    TORCH_INTERNAL_ASSERT(c == outer_strides[1]);

    char* C10_RESTRICT data[ntensors] = {base[0], base[1]};
    const int64_t size = size0 * size1;

    int64_t i = 0;

    if (c >= 2 && c <= 16) {
      i = vectorized_cpu_hflip_channels_last(data, size * stride, c, mdata) / stride;
    }

    auto data_stride = size0 * stride;
    for (; i < size; i += size0) {

      memcpy(data[0], data[1], data_stride);

      // advance:
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] += outer_strides[arg];
      }
    }

  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

void flip_kernel(TensorIterator& iter, const bool quantized) {
  if (quantized) {
    AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cpu",
        [&iter] { cpu_kernel(iter,
          [](scalar_t a, scalar_t /*dummy input*/) -> scalar_t {
            return a;
        });
    });
  } else {
    auto output_strides = iter.strides(0);
    auto input_strides = iter.strides(1);
    if (iter.ndim() > 0 && output_strides[0] == -iter.element_size(0) && input_strides[0] == iter.element_size(1)) {
      // Special case: vectorized horizontal flip with a contiguous input.
      // Context: a horizontal flip makes strides[0] < 0, so the is_contiguous
      // condition is not satisfied and the non-vectorized code path would be taken.
      auto iter_dtype = iter.dtype();
      // Ignoring half and bfloat16 as cpu_hflip_vec is slower than cpu_kernel_vec
      if (isIntegralType(iter_dtype, true) || iter_dtype == kDouble || iter_dtype == kFloat) {
        // Replace AT_DISPATCH_ALL_TYPES_AND by a manual if/else due to internal test failures:
        // - "dtype 'Float' not selected for kernel tag hflip_cpu"
        // - "dtype 'Long' not selected for kernel tag hflip_cpu"
        //
        // AT_DISPATCH_ALL_TYPES_AND(kBool,
        //     iter_dtype, "hflip_cpu", [&iter] {
        //       cpu_hflip_vec<scalar_t>(iter);
        // });

        if (iter_dtype == kByte) {
          return cpu_hflip_vec<uint8_t>(iter);
        } else if (iter_dtype == kChar) {
          return cpu_hflip_vec<int8_t>(iter);
        } else if (iter_dtype == kInt) {
          return cpu_hflip_vec<int32_t>(iter);
        } else if (iter_dtype == kLong) {
          return cpu_hflip_vec<int64_t>(iter);
        } else if (iter_dtype == kShort) {
          return cpu_hflip_vec<int16_t>(iter);
        } else if (iter_dtype == kBool) {
          return cpu_hflip_vec<bool>(iter);
        } else if (iter_dtype == kFloat) {
          return cpu_hflip_vec<float>(iter);
        } else if (iter_dtype == kDouble) {
          return cpu_hflip_vec<double>(iter);
        }
      }
      // other dtypes (float16, bfloat16, complex) are handled by cpu_kernel_vec (see below)
    } else if (iter.has_contiguous_first_dim()) {
      // Special cases:
      // a) channels-last hflip on (N, C, H, W) with outer_stride (= dtype_size * C) in [2, 16]
      // b) flip dim=-2 on (N, ..., M, C) with outer_stride (= dtype_size * C) in [2, 16]
      auto output_strides_2 = iter.strides(0);
      auto input_strides_2 = iter.strides(1);
      auto c = -output_strides_2[1];
      if (c >= 2 && c <= 16 &&
          c == input_strides_2[1] &&
          c == iter.element_size(0) * iter.shape()[0]  // checks that dim=1 is contiguous as well
      ) {
        return cpu_hflip_channels_last_vec(iter);
      }
      // Special case: vertical flip using memcpy (faster than generic cpu_kernel_vec)
      return cpu_vflip_memcpy(iter);
    }

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(), "flip_cpu",
        [&iter] { cpu_kernel_vec(iter,
          [](scalar_t a, scalar_t /*dummy input*/) -> scalar_t {
            return a;
        },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> /*dummy input*/) -> Vectorized<scalar_t> {
            return a;
        });
    });
  }
}

} // anonymous namespace

REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_fill_stub, &index_fill_kernel);
REGISTER_DISPATCH(index_copy_stub, &index_copy_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
REGISTER_DISPATCH(put_stub, &put_kernel);
REGISTER_DISPATCH(take_stub, &take_kernel);
REGISTER_DISPATCH(masked_fill_stub, &masked_fill_kernel);
REGISTER_DISPATCH(masked_select_serial_stub, &masked_select_serial_kernel);
REGISTER_DISPATCH(masked_select_stub, &masked_select_kernel);
REGISTER_DISPATCH(masked_scatter_stub, &masked_scatter_kernel);
REGISTER_DISPATCH(flip_stub, &flip_kernel);

} // namespace at::native