xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/sparse/SparseStubs.h>
3 #include <ATen/native/sparse/FlattenIndicesCommon.h>
4 #include <ATen/native/cpu/Loops.h>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/AccumulateType.h>
7 
8 namespace at::native {
9 
10 namespace {
11 
12 template <typename func_t>
13 struct CPUKernelLauncher {
launchat::native::__anon8f4f8c560111::CPUKernelLauncher14   static void launch(TensorIteratorBase& iter, const func_t& f) {
15     cpu_kernel(iter, f);
16   }
17 };
18 
// CPU implementation registered with flatten_indices_stub below.
//
// Delegates entirely to _flatten_indices, instantiated with
// CPUKernelLauncher, and returns whatever that call produces.
//
// indices, size: forwarded unchanged to _flatten_indices.
// NOTE(review): their expected layout is defined by _flatten_indices
// (presumably in FlattenIndicesCommon.h) -- confirm there.
Tensor flatten_indices_cpu_kernel(const Tensor& indices, IntArrayRef size) {
  Tensor flattened = _flatten_indices<CPUKernelLauncher>(indices, size);
  return flattened;
}
22 
23 }
24 
// Register flatten_indices_cpu_kernel with flatten_indices_stub for each of
// the dispatch variants named here: DEFAULT, AVX512, AVX2, VSX and ZVECTOR.
// Every variant is given the same function pointer, so no variant-specific
// implementation is selected by these lines.
// NOTE(review): presumably the same kernel is registered everywhere because
// this translation unit is not compiled once per CPU capability -- confirm
// against the build configuration and DispatchStub.h.
25 REGISTER_ARCH_DISPATCH(flatten_indices_stub, DEFAULT, &flatten_indices_cpu_kernel);
26 REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
27 REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
28 REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
29 REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
30 
31 } // namespace at::native
32