#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/IndexKernel.h>

#include <cmath>
#include <iostream>

#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Parallel.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/AtomicAddFloat.h>
#include <ATen/native/cpu/IndexKernelUtils.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#include <c10/core/Scalar.h>

namespace at::native {
namespace {

using namespace vec;

void index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
    iter.dtype(), "index_cpu", [&] {
    cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
      *(scalar_t*)dst = *(scalar_t*)(src + offset);
    });
  });
}

// Given a linear index, returns the offset of the tensor.
// Implements the same algorithm as its (legacy) GPU version cuda::detail::IndexToOffset
// OffsetCalculator implements yet again the same algorithm but in a column-major order
struct IndexToOffset {
  const IntArrayRef sizes;
  const IntArrayRef strides;
  const int64_t ndim;
  explicit IndexToOffset(const TensorBase & tensor) :
      sizes(tensor.sizes()), strides(tensor.strides()), ndim(tensor.dim()) {
  }

  int64_t get(int64_t linear_index) const {
    int64_t offset = 0;
    for (int64_t i = ndim - 1; i > 0; i--) {
      offset += (linear_index % sizes[i]) * strides[i];
      linear_index /= sizes[i];
    }
    return offset + linear_index * strides[0];
  }
};
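
// For intuition, a small illustrative example of IndexToOffset::get() (not from the code above,
// just a worked case): take a 2x3 tensor with sizes = {2, 3} and strides = {1, 2}, i.e. the
// transpose of a contiguous 3x2 tensor. For linear (row-major) index 4, which is element (1, 1):
//   get(4): offset = (4 % 3) * strides[1] = 1 * 2 = 2;  linear_index = 4 / 3 = 1;
//           result = 2 + 1 * strides[0] = 3,
// which is indeed where element (1, 1) lives in the underlying storage.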

template <typename scalar_t, typename func_t>
void cpu_take_put_kernel(
    TensorIterator& iter,
    const TensorBase& indexed,
    bool is_indexed_data_mutated,
    const func_t& f,
    bool serial_execution=false) {
  // This kernel follows the same strategy as `cpu_index_kernel`.
  // Even though the indexed tensor is const, we modify it through its data_ptr.
  // This is a bit dirty, but the alternative would be to add a zero-strided tensor
  // to `iter`, which would not be much better.

  // When launching the parallel version, use a grain size smaller than internal::GRAIN_SIZE
  // so that the available threads get a more balanced workload and better cache locality.
  // The grain size here was chosen via op benchmarks to amortize the thread-launch overhead.
  // Perhaps tweak this number for `put_`? This number was tuned for `index_put`.
  constexpr int parallel_grain_size = 3000;
  const bool is_contiguous = indexed.is_contiguous();
  const auto numel = indexed.numel();
  const auto offset_indexed = IndexToOffset(indexed);

  auto* indexed_data = is_indexed_data_mutated ?
      indexed.data_ptr<scalar_t>()
      : const_cast<scalar_t*>(indexed.const_data_ptr<scalar_t>());
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* iterated_data_bytes = data[0];
    auto* index_data_bytes = data[1];
    for (const auto elem C10_UNUSED : c10::irange(n)) {
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      auto& iterated = *reinterpret_cast<scalar_t*>(iterated_data_bytes);

      TORCH_CHECK_INDEX(idx >= -numel && idx < numel,
                        "out of range: tried to access index ",
                        idx, " on a tensor of ", numel, " elements.");
      if (idx < 0) {
        idx += numel;
      }
      if (!is_contiguous) {
        idx = offset_indexed.get(idx);
      }
      f(iterated, indexed_data, idx);
      iterated_data_bytes += strides[0];
      index_data_bytes += strides[1];
    }
  };
  if (serial_execution) {
    iter.serial_for_each(loop, {0, iter.numel()});
  } else {
    iter.for_each(loop, parallel_grain_size);
  }
}
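
// The callback `f(iterated, indexed_data, idx)` passed to cpu_take_put_kernel above decides the
// direction of the data movement. For example (see `take_kernel` and `put_kernel` below for the
// actual lambdas):
//   take: iterated = indexed_data[idx];   // read from the indexed tensor
//   put:  indexed_data[idx] = iterated;   // write into the indexed tensor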

void put_kernel(
  TensorIterator& iter,
  const TensorBase & self,
  const bool accumulate) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
    iter.dtype(), "take_put_cpu", [&] {
    // iter could be const, but for_each does not have a const version
    if (accumulate) {
      // N.B. This determinism issue is the same as the one in `index_put_kernel`.
      // See Note [Enabling Deterministic Operations]
      // Parallel cpu_put_kernel with accumulation is nondeterministic, so we
      // must enable serial execution if deterministic algorithms are enabled.
      bool is_deterministic = at::globalContext().deterministicAlgorithms();
      bool use_parallel_for = (!is_deterministic) && (
        (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
      if (use_parallel_for && iter.dtype() == ScalarType::Float) {
        cpu_take_put_kernel<float>(iter, self, true,
            [](float& iterated, float* indexed, const int64_t idx) {
                cpu_atomic_add_float(indexed+idx, iterated);
            });
      } else {
        // TODO: investigate parallelization of the accumulate kernel.
        // Unlike the non-accumulate case, this needs to be thread-safe.
        cpu_take_put_kernel<scalar_t>(iter, self, true,
            [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
                indexed[idx] += iterated;
            },
            /*serial_execution=*/true);
      }
    } else {
      cpu_take_put_kernel<scalar_t>(iter, self, true,
          [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
              indexed[idx] = iterated;
          });
    }
  });
}

void take_kernel(
  TensorIterator& iter,
  const TensorBase & input) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
    iter.dtype(), "take_cpu", [&] {
      cpu_take_put_kernel<scalar_t>(iter, input, false,
          [](scalar_t& iterated, const scalar_t* indexed, const int64_t idx) {
              iterated = indexed[idx];
          });
    });
}

void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
  // NOTE: duplicate indices are only supported if accumulate is true.
  AT_DISPATCH_V2(
    iter.dtype(),
    "index_put",
    AT_WRAP([&] {
      // See Note [Enabling Deterministic Operations]
      // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
      // must enable serial execution if deterministic algorithms are enabled.
      const bool is_deterministic = at::globalContext().deterministicAlgorithms();
      if (accumulate) {
        bool use_parallel_for = (!is_deterministic) && (
          (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
        if (use_parallel_for && iter.dtype() == ScalarType::Float) {
          cpu_index_kernel<float>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
            cpu_atomic_add_float((float*)(dst + offset), *(float*)src);
          });
        } else {
          // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
          // this needs to be thread-safe.
          cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
            *(scalar_t*)(dst + offset) += *(scalar_t*)src;
          }, /*serial_execution=*/true);
        }
      } else {
        cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset) = *(scalar_t*)src;
        }, /*serial_execution=*/is_deterministic);
      }
    }),
    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
    AT_EXPAND(AT_FLOAT8_TYPES),
    kComplexHalf,
    kHalf,
    kBool,
    kBFloat16);
}

void index_fill_kernel(
  TensorIterator& iter,
  int64_t dim,
  int64_t self_dim_size,
  int64_t self_dim_stride,
  const Scalar& source) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_fill_cpu", [&] {
    auto fill_val = source.to<scalar_t>();
    auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      for (const auto elem C10_UNUSED : c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
        auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
        TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
                          "index ", idx, " is out of bounds for dimension ",
                          dim, " with size ", self_dim_size);
        if (idx < 0) {
          idx += self_dim_size;
        }

        self_data[idx * self_dim_stride] = fill_val;

        self_data_bytes += strides[0];
        index_data_bytes += strides[1];
      }
    };
    auto handle_zero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
                        "index ", idx, " is out of bounds for dimension ",
                        dim, " with size ", self_dim_size);
      if (idx < 0) {
        idx += self_dim_size;
      }
      for (const auto elem C10_UNUSED: c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);

        self_data[idx * self_dim_stride] = fill_val;

        self_data_bytes += strides[0];
      }
    };

    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      auto idx_stride = strides[1];
      if (idx_stride) {
        handle_nonzero_idx_stride(data, strides, n);
      }
      else {
        handle_zero_idx_stride(data, strides, n);
      }
    };
    iter.for_each(loop);
  });
}

void index_copy_kernel(
  TensorIterator& iter,
  int64_t dim,
  int64_t self_dim_size,
  int64_t self_dim_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_copy_cpu", [&] {
    auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      auto* source_data_bytes = data[2];
      for (const auto elem C10_UNUSED : c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
        auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
        auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);
        TORCH_CHECK_INDEX(idx >= 0 && idx < self_dim_size,
                          "index_copy_(): index ", idx, " is out of bounds for dimension ",
                          dim, " with size ", self_dim_size);

        self_data[idx * self_dim_stride] = *source_data;

        self_data_bytes += strides[0];
        index_data_bytes += strides[1];
        source_data_bytes += strides[2];
      }
    };
    auto handle_zero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
      auto* self_data_bytes = data[0];
      auto* index_data_bytes = data[1];
      auto* source_data_bytes = data[2];
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      TORCH_CHECK_INDEX(idx >= 0 && idx < self_dim_size,
                        "index_copy_(): index ", idx, " is out of bounds for dimension ",
                        dim, " with size ", self_dim_size);
      for (const auto elem C10_UNUSED : c10::irange(n)) {
        auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
        auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);

        self_data[idx * self_dim_stride] = *source_data;

        self_data_bytes += strides[0];
        source_data_bytes += strides[2];
      }
    };

    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      auto idx_stride = strides[1];
      if (idx_stride) {
        handle_nonzero_idx_stride(data, strides, n);
      }
      else {
        handle_zero_idx_stride(data, strides, n);
      }
    };
    bool is_deterministic = at::globalContext().deterministicAlgorithms();
    if (is_deterministic) {
      iter.serial_for_each(loop, {0, iter.numel()});
    } else {
      iter.for_each(loop);
    }
  });
}

template <typename scalar_t>
void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) {
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    char* mask = data[1];
    for (const auto i : c10::irange(n)) {
      bool mask_value = *reinterpret_cast<bool*>(mask + strides[1] * i);

      if (mask_value) {
        *(scalar_t*)(dst + strides[0] * i) = value;
      }
    }
  };
  iter.for_each(loop);
}

void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf,
    iter.dtype(), "masked_fill", [&] {
      scalar_t scalar_val = value.to<scalar_t>();
      auto mask_dtype = iter.input_dtype(0);
      TORCH_CHECK(mask_dtype == ScalarType::Bool, "masked_fill only supports boolean masks, "
        "but got mask with dtype ", mask_dtype);
      cpu_masked_fill_kernel<scalar_t>(iter, scalar_val);
    });
}

template <typename scalar_t>
void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
  std::ptrdiff_t source_cntr = 0;
  const scalar_t* source_ptr = source.const_data_ptr<scalar_t>();
  auto numel = source.numel();

  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    const int64_t dst_stride = strides[0];
    char* mask = data[1];
    const int64_t mask_stride = strides[1];
    for (const auto i : c10::irange(n)) {
      auto mask_value = *reinterpret_cast<bool*>(mask + mask_stride * i);
      if (mask_value) {
        TORCH_CHECK(source_cntr < numel, "Number of elements of source < number of ones in mask");
        *(scalar_t*)(dst + dst_stride * i) = *(source_ptr);
        source_ptr++;
        source_cntr++;
      }
    }
  };
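  // Serial execution is required here: `source_ptr` / `source_cntr` advance across loop
  // invocations, so selected destination elements must be visited in order.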
  iter.serial_for_each(loop, {0, iter.numel()});
}

void masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
  TORCH_CHECK(iter.input_dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
    "but got mask with dtype ", iter.input_dtype());
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool,
      ScalarType::BFloat16,
      ScalarType::Half,
      iter.dtype(),
      "masked_scatter",
      [&] {
        cpu_masked_scatter_kernel<scalar_t>(iter, source);
      });
}

template <typename scalar_t, typename mask_t, typename func_t>
void cpu_masked_select_serial_kernel(TensorIterator& iter, const func_t& f) {
  int64_t offset = 0;
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    char* src = data[1];
    char* mask = data[2];
    for (const auto i : c10::irange(n)) {
      mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
      if constexpr (!std::is_same<mask_t, bool>::value) {
        TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
      }
      if (mask_value) {
        int64_t offset_bytes = offset * sizeof(scalar_t);
        f(dst, src + strides[1] * i, offset_bytes);
        offset++;
      }
    }
  };
  iter.serial_for_each(loop, {0, iter.numel()});
}

void masked_select_serial_kernel(TensorIterator& iter, int64_t result_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
    iter.dtype(), "masked_select", [&] {
      auto mask_dtype = iter.input_dtype(1);
      if (mask_dtype == ScalarType::Bool) {
        cpu_masked_select_serial_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      } else {
        cpu_masked_select_serial_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      }
    });
}

template <typename scalar_t, typename mask_t, typename func_t>
void cpu_masked_select_kernel(TensorIterator& iter, const func_t& f) {
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    char* dst = data[0];
    char* src = data[1];
    char* mask = data[2];
    char* mask_prefix_sum = data[3];
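    // `mask_prefix_sum` holds the cumulative count of mask elements up to and including i
    // (hence the `offset - 1` below when computing the destination slot). Precomputing it is
    // what allows this variant to run in parallel, unlike the serial kernel above which keeps
    // a running `offset` counter.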
    for (const auto i : c10::irange(n)) {
      mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
      if constexpr (!std::is_same<mask_t, bool>::value) {
        TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
      }
      if (mask_value) {
        int64_t offset = *(int64_t*)(mask_prefix_sum + strides[3] * i);
        int64_t offset_bytes = (offset - 1) * sizeof(scalar_t);
        f(dst, src + strides[1] * i, offset_bytes);
      }
    }
  };
  iter.for_each(loop);
}

void masked_select_kernel(TensorIterator& iter, int64_t result_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half,
    iter.dtype(), "masked_select", [&] {
      auto mask_dtype = iter.input_dtype(1);
      if (mask_dtype == ScalarType::Bool) {
        cpu_masked_select_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      } else {
        cpu_masked_select_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
          *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
        });
      }
    });
}

template <typename scalar_t>
void cpu_hflip_vec(at::TensorIterator& iter) {

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // Here ntensors covers the output and one input. The tensor iterator actually defines three
    // operands: output, input, and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66),
    // but we only use the output and the input.
    static constexpr int ntensors = 2;
    const int64_t *outer_strides = &strides[3];

    std::array<char*, ntensors> data_arr;
    std::copy_n(base, ntensors, data_arr.data());

    using Vec = Vectorized<scalar_t>;

    constexpr auto stride = sizeof(scalar_t);
    TORCH_INTERNAL_ASSERT(stride == -strides[0] && stride == strides[1]);

    for (const auto j C10_UNUSED : c10::irange(size1)) {

      // vectorized loop with negative stride for output
      char** C10_RESTRICT data_ = data_arr.data();
      int64_t n = size0;

      char* C10_RESTRICT data[ntensors];
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] = data_[arg];
      }

      int64_t i = 0;

      // data[0] unaligned pre-pass
      int64_t offset = (j * n + (n - i - Vec::size())) % 32;
      offset = (offset >= n) ? n : offset;
      for (; i < offset; i++) {
        scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
        *out_ptr = *(scalar_t *)(data[1] + i * stride);
      }
      // Empirically found that it is faster to process 3 data items together vs 2 or 4
      for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) {
        auto out1 = Vec::loadu(data[1] + i * stride);
        auto out2 = Vec::loadu(data[1] + (i + Vec::size()) * stride);
        auto out3 = Vec::loadu(data[1] + (i + 2 * Vec::size()) * stride);
        // flip the vector: 1234 -> 4321
        out1 = flip(out1);
        out2 = flip(out2);
        out3 = flip(out3);
        out1.store(data[0] - (i + Vec::size() - 1) * stride);
        out2.store(data[0] - (i + 2 * Vec::size() - 1) * stride);
        out3.store(data[0] - (i + 3 * Vec::size() - 1) * stride);
      }
      if (i < n) {
        for (; i < n; i++) {
          scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
          *out_ptr = *(scalar_t *)(data[1] + i * stride);
        }
      }

      // advance:
      for (const auto arg : c10::irange(ntensors)) {
        data_arr[arg] += outer_strides[arg];
      }
    }
  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

void cpu_vflip_memcpy(at::TensorIterator& iter) {
  // This is a vertical-flip specialization that uses memcpy to speed up the runtime

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // Here ntensors covers the output and one input. The tensor iterator actually defines three
    // operands: output, input, and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66),
    // but we only use the output and the input.
    static constexpr int ntensors = 2;
    const int64_t *outer_strides = &strides[3];

    std::array<char*, ntensors> data_arr;
    std::copy_n(base, ntensors, data_arr.data());

    TORCH_INTERNAL_ASSERT(strides[0] == strides[1]);
    const int64_t stride = strides[0];

    for (const auto j C10_UNUSED : c10::irange(size1)) {

      char** C10_RESTRICT data_ = data_arr.data();
      int64_t n = size0;

      char* C10_RESTRICT data[ntensors];
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] = data_[arg];
      }

      memcpy(data[0], data[1], n * stride);

      // advance:
      for (const auto arg : c10::irange(data_arr.size())) {
        data_arr[arg] += outer_strides[arg];
      }
    }
  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

constexpr int64_t hflip_mask_size = 32;

std::array<char, hflip_mask_size> generate_vec_hflip_reg_mask(int64_t data_stride) {
  std::array<char, hflip_mask_size> mask;
  for (const auto k : c10::irange(hflip_mask_size / 2)) {
    int j = k / data_stride + 1;
    int v = (j * data_stride - 1) - (k % data_stride);
    v = std::min(v, (int) (hflip_mask_size / 2 - 1));
    mask[hflip_mask_size - 1 - k] = v;
    mask[hflip_mask_size / 2 - 1 - k] = v;
  }
  return mask;
}
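
// Illustrative example (a worked case, not taken from the code above): for data_stride = 3
// (e.g. contiguous uint8 RGB pixels), each 16-byte half of the returned mask is
//   {15, 12, 13, 14, 9, 10, 11, 6, 7, 8, 3, 4, 5, 0, 1, 2}
// i.e. a _mm256_shuffle_epi8 control that reverses the order of the 3-byte pixels within each
// 128-bit lane. The leading byte of each half (source byte 15) only covers a partial pixel;
// the overlapping two-part store described in the comment inside
// vectorized_cpu_hflip_channels_last below takes care of that.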

int64_t vectorized_cpu_hflip_channels_last(
    char * C10_RESTRICT *data, const int64_t data_size, const int64_t data_stride, const std::array<char, 32> & mdata) {

  int64_t i = 0;
#ifdef CPU_CAPABILITY_AVX2

  constexpr auto vec_size = 256 / 8;

  if (data_size > vec_size) {

    // Example for num channels=3 and dtype=uint8
    // -> data_stride = 3
    // -> usable_vec_stride = 30
    // -> usable_vec_half_stride = 15
    // Data: (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 32 33)
    // load by 2 parts
    // R = [ (1 2 3) (4 5 6) (7 8 9) (10 11 12) (13 14 15) (16 | (16 17 18) (19 20 21) (22 23 24) (25 26 27) (28 29 30) (31 ]
    // flip(R) ->
    // R = [ 31 (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) | 16 (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3) ]
    //
    // Write in 2 parts
    // Output pointer: output_ptr = data[0]                                                                      v
    // - Init:
    //   (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X)
    // 0) Move to initial position: output_ptr = data[0] + data_stride - vec_size / 2;
    //   v
    //   (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X) (X X X)
    // - In the loop:
    // 1) Write 1st block from output_ptr
    //   v
    //   |----> vec_size / 2 ---------------------------|
    // Output part 1: (X X X) (X X X) (X X X) (X X X) (X X X) (X X 16) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)
    // 2) Write 2nd block from output_ptr - usable_vec_half_stride:
    //   v
    //   |-----> vec_size / 2 ----------------------------------|
    // Output part 2: (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)
    //
    // 3) Move to the next position: output_ptr -= usable_vec_stride
    //
    // - After the loop:
    // 4) Move to write position
    //   v
    //   (X X 31) (28 29 30) (25 26 27) (22 23 24) (19 20 21) (16 17 18) (13 14 15) (10 11 12) (7 8 9) (4 5 6) (1 2 3)

    const __m256i mask = _mm256_loadu_si256((__m256i *) mdata.data());

    const auto usable_vec_stride = 2 * (vec_size / 2 / data_stride) * data_stride;
    const auto usable_vec_half_stride = usable_vec_stride / 2;

    auto output_ptr = data[0] + data_stride - vec_size / 2;
    auto input_ptr = data[1];

    for (; i < data_size - vec_size; i += usable_vec_stride) {

      // load 256-bits by two 128-bits parts
      auto a0 = _mm_loadu_si128((__m128i *) (input_ptr + i));
      auto b0 = _mm256_castsi128_si256(a0);
      auto a1 = _mm_loadu_si128((__m128i *) (input_ptr + i + usable_vec_half_stride));
      auto data_vec = _mm256_inserti128_si256(b0, a1, 1);

      auto reversed_vec = _mm256_shuffle_epi8(data_vec, mask);

      // write output in two parts
      auto rev_vec_h = _mm256_extracti128_si256(reversed_vec, 0);
      _mm_storeu_si128((__m128i *) (output_ptr - i), rev_vec_h);
      auto rev_vec_l = _mm256_extracti128_si256(reversed_vec, 1);
      _mm_storeu_si128((__m128i *) (output_ptr - i - usable_vec_half_stride), rev_vec_l);
    }

    data[0] -= i;
    data[1] += i;
  }
#endif
  return i;
}

void cpu_hflip_channels_last_vec(at::TensorIterator& iter) {

  auto input_strides = iter.strides(1);
  const auto data_stride = input_strides[1];

  // Generate avx mask once
  alignas(hflip_mask_size) auto mdata = generate_vec_hflip_reg_mask(data_stride);

  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {

    // Here ntensors covers the output and one input. The tensor iterator actually defines three
    // operands: output, input, and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66),
    // but we only use the output and the input.
    static constexpr int ntensors = 2;
    const int64_t *outer_strides = &strides[3];
    const int64_t stride = strides[0];

    TORCH_INTERNAL_ASSERT(stride == strides[1]);

    auto c = -outer_strides[0];
    TORCH_INTERNAL_ASSERT(c == outer_strides[1]);

    char* C10_RESTRICT data[ntensors] = {base[0], base[1]};
    const int64_t size = size0 * size1;

    int64_t i = 0;

    if (c >= 2 && c <= 16) {
      i = vectorized_cpu_hflip_channels_last(data, size * stride, c, mdata) / stride;
    }

    auto data_stride = size0 * stride;
    for (; i < size; i += size0) {

      memcpy(data[0], data[1], data_stride);

      // advance:
      for (const auto arg : c10::irange(ntensors)) {
        data[arg] += outer_strides[arg];
      }
    }

  };

  int64_t grain_size = at::internal::GRAIN_SIZE;
  iter.for_each(loop2d, grain_size);
  iter.cast_outputs();
}

void flip_kernel(TensorIterator& iter, const bool quantized) {
  if (quantized) {
    AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cpu",
        [&iter] { cpu_kernel(iter,
          [](scalar_t a, scalar_t /*dummy input*/) -> scalar_t {
            return a;
        });
    });
  } else {
    auto output_strides = iter.strides(0);
    auto input_strides = iter.strides(1);
    if (iter.ndim() > 0 && output_strides[0] == -iter.element_size(0) && input_strides[0] == iter.element_size(1)) {
      // Special case: horizontal flip with vectorization when the input is contiguous.
      // Context: a horizontal flip makes strides[0] < 0, so the is_contiguous condition
      // is not satisfied and the non-vectorized code path would otherwise be taken.
      auto iter_dtype = iter.dtype();
      // Ignoring half and bfloat16 as cpu_hflip_vec is slower than cpu_kernel_vec
      if (isIntegralType(iter_dtype, true) || iter_dtype == kDouble || iter_dtype == kFloat) {
        // Replace AT_DISPATCH_ALL_TYPES_AND by manual if/else due to internal test failures:
        // - "dtype 'Float' not selected for kernel tag hflip_cpu"
        // - "dtype 'Long' not selected for kernel tag hflip_cpu"
        //
        // AT_DISPATCH_ALL_TYPES_AND(kBool,
        //     iter_dtype, "hflip_cpu", [&iter] {
        //       cpu_hflip_vec<scalar_t>(iter);
        // });

        if (iter_dtype == kByte) {
          return cpu_hflip_vec<uint8_t>(iter);
        } else if (iter_dtype == kChar) {
          return cpu_hflip_vec<int8_t>(iter);
        } else if (iter_dtype == kInt) {
          return cpu_hflip_vec<int32_t>(iter);
        } else if (iter_dtype == kLong) {
          return cpu_hflip_vec<int64_t>(iter);
        } else if (iter_dtype == kShort) {
          return cpu_hflip_vec<int16_t>(iter);
        } else if (iter_dtype == kBool) {
          return cpu_hflip_vec<bool>(iter);
        } else if (iter_dtype == kFloat) {
          return cpu_hflip_vec<float>(iter);
        } else if (iter_dtype == kDouble) {
          return cpu_hflip_vec<double>(iter);
        }
      }
      // other dtypes (float16, bfloat16, complex) are handled by cpu_kernel_vec (see below)
    } else if (iter.has_contiguous_first_dim()) {
      // Special cases:
      // a) channels last hflip on (N, C, H, W) and outer_stride(=dtype_size * C) in [2, 16]
      // b) flip dim=-2 on (N, ..., M, C) and outer_stride(=dtype_size * C) in [2, 16]
      auto output_strides_2 = iter.strides(0);
      auto input_strides_2 = iter.strides(1);
      auto c = -output_strides_2[1];
      if (c >= 2 && c <= 16 &&
          c == input_strides_2[1] &&
          c == iter.element_size(0) * iter.shape()[0]  // checks if dim=1 is contiguous as well
      ) {
        return cpu_hflip_channels_last_vec(iter);
      }
      // Special case: vertical flip using memcpy (faster than generic cpu_kernel_vec)
      return cpu_vflip_memcpy(iter);
    }

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(), "flip_cpu",
        [&iter] { cpu_kernel_vec(iter,
          [](scalar_t a, scalar_t /*dummy input*/) -> scalar_t {
            return a;
        },
          [](Vectorized<scalar_t> a, Vectorized<scalar_t> /*dummy input*/) -> Vectorized<scalar_t> {
            return a;
        });
    });
  }
}

} // anonymous namespace

REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_fill_stub, &index_fill_kernel);
REGISTER_DISPATCH(index_copy_stub, &index_copy_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
REGISTER_DISPATCH(put_stub, &put_kernel);
REGISTER_DISPATCH(take_stub, &take_kernel);
REGISTER_DISPATCH(masked_fill_stub, &masked_fill_kernel);
REGISTER_DISPATCH(masked_select_serial_stub, &masked_select_serial_kernel);
REGISTER_DISPATCH(masked_select_stub, &masked_select_kernel);
REGISTER_DISPATCH(masked_scatter_stub, &masked_scatter_kernel);
REGISTER_DISPATCH(flip_stub, &flip_kernel);

} // namespace at::native