xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/FlattenIndicesKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/sparse/SparseStubs.h>
3 #include <ATen/native/sparse/FlattenIndicesCommon.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/cuda/KernelUtils.cuh>
6 #include <ATen/cuda/detail/OffsetCalculator.cuh>
7 #include <ATen/AccumulateType.h>
8 
9 namespace at::native {
10 
11 namespace {
12 
13 template <typename func_t>
14 struct CUDAKernelLauncher {
launchat::native::__anon9940a87a0111::CUDAKernelLauncher15   static void launch(TensorIteratorBase& iter, const func_t& f) {
16     gpu_kernel(iter, f);
17   }
18 };
19 
// CUDA entry point for index flattening.
// All of the work is delegated to the `_flatten_indices` template, which is
// instantiated here with `CUDAKernelLauncher` as its launcher policy
// (presumably `_flatten_indices` comes from FlattenIndicesCommon.h and calls
// the launcher's `launch` hook -- confirm against that header).
//
// indices: forwarded unchanged to `_flatten_indices`.
// size:    forwarded unchanged to `_flatten_indices`.
// Returns whatever `_flatten_indices` produces for these arguments.
Tensor flatten_indices_cuda_kernel(const Tensor& indices, IntArrayRef size) {
  Tensor flattened = _flatten_indices<CUDAKernelLauncher>(indices, size);
  return flattened;
}
23 
24 }
25 
// Hands `flatten_indices_stub` and the address of `flatten_indices_cuda_kernel`
// to the REGISTER_CUDA_DISPATCH macro. NOTE(review): presumably this installs
// the function as the CUDA implementation behind the stub -- the macro's
// definition is not visible in this file, confirm against DispatchStub.h.
26 REGISTER_CUDA_DISPATCH(flatten_indices_stub, &flatten_indices_cuda_kernel);
27 
28 } // namespace at::native
29