1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/sparse/SparseStubs.h>
3 #include <ATen/native/sparse/FlattenIndicesCommon.h>
4 #include <ATen/native/cpu/Loops.h>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/AccumulateType.h>
7
8 namespace at::native {
9
10 namespace {
11
// Adapter that satisfies the kernel-launcher interface expected by
// _flatten_indices (see FlattenIndicesCommon.h) on CPU: it simply forwards
// the elementwise functor to ATen's cpu_kernel loop over the TensorIterator.
// A parallel launcher exists for CUDA elsewhere; this template parameter
// lets the shared flattening logic stay device-agnostic.
template <typename func_t>
struct CPUKernelLauncher {
  // Run functor `f` elementwise over `iter` using the standard CPU loop.
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    cpu_kernel(iter, f);
  }
};
18
// CPU entry point for the flatten_indices dispatch stub.
//
// Collapses multi-dimensional sparse `indices` (shape: ndim x nnz) into a
// single linear index per entry, interpreting `size` as the dense extent of
// each sparse dimension. All real work happens in the shared
// _flatten_indices template; this function only binds it to the CPU launcher.
Tensor flatten_indices_cpu_kernel(const Tensor& indices, IntArrayRef size) {
  auto flattened = _flatten_indices<CPUKernelLauncher>(indices, size);
  return flattened;
}
22
23 }
24
// Register the same scalar implementation for every CPU capability level.
// flatten_indices has no vectorized variants, so the DEFAULT kernel is reused
// for AVX512/AVX2/VSX/ZVECTOR builds; the DispatchStub machinery still
// requires an explicit registration per architecture.
REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel);
REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
30
31 } // namespace at::native
32