xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/ValidateCompressedIndicesKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/native/sparse/ValidateCompressedIndicesCommon.h>
2 #include <ATen/native/cpu/Loops.h>
3 
4 #ifdef AT_PER_OPERATOR_HEADERS
5 #include <ATen/ops/_validate_compressed_sparse_indices_native.h>
6 #endif
7 
8 namespace at::native {
9 
10 namespace {
11 
12 template <typename func_t>
13 struct CPUKernel {
launchat::native::__anone85d6fa70111::CPUKernel14   static void launch(TensorIteratorBase& iter, const func_t& f) {
15     cpu_kernel(iter, f);
16   }
17 };
18 
19 template <typename func_t>
20 struct EmptyKernel {
launchat::native::__anone85d6fa70111::EmptyKernel21   static void launch(TensorIteratorBase& iter, const func_t& f) {
22   }
23 };
24 
25 template <typename func_t, typename vec_func_t>
26 struct CPUVecKernel {
launchat::native::__anone85d6fa70111::CPUVecKernel27   static void launch(TensorIteratorBase& iter, const func_t& f, const vec_func_t& vec_f) {
28     cpu_kernel_vec(iter, f, vec_f);
29   }
30 };
31 
32 }
33 
_validate_compressed_sparse_indices_cpu(const bool is_crow,const Tensor & cidx,const Tensor & idx,const int64_t cdim,const int64_t dim,const int64_t nnz)34 void _validate_compressed_sparse_indices_cpu(
35     const bool is_crow,
36     const Tensor& cidx,
37     const Tensor& idx,
38     const int64_t cdim,
39     const int64_t dim,
40     const int64_t nnz) {
41   // Call into
42   // compressed_index_invariance_checks_kernel<EmptyKernel, CPUVecKernel, Vectorized>
43   // to enable vectorized checks once all the conditions for that are met,
44   // see ATen/native/sparse/CompressedIndexChecksCommon.h for more details.
45   validate_compressed_sparse_indices_kernel<CPUKernel>(
46       is_crow, cidx, idx, cdim, dim, nnz);
47 }
48 
49 } //namespace at::native
50