#include <ATen/native/sparse/ValidateCompressedIndicesCommon.h>
#include <ATen/native/cpu/Loops.h>

#ifdef AT_PER_OPERATOR_HEADERS
#include <ATen/ops/_validate_compressed_sparse_indices_native.h>
#endif

namespace at::native {

namespace {

12 template <typename func_t>
13 struct CPUKernel {
launchat::native::__anone85d6fa70111::CPUKernel14 static void launch(TensorIteratorBase& iter, const func_t& f) {
15 cpu_kernel(iter, f);
16 }
17 };
18
19 template <typename func_t>
20 struct EmptyKernel {
launchat::native::__anone85d6fa70111::EmptyKernel21 static void launch(TensorIteratorBase& iter, const func_t& f) {
22 }
23 };
24
25 template <typename func_t, typename vec_func_t>
26 struct CPUVecKernel {
launchat::native::__anone85d6fa70111::CPUVecKernel27 static void launch(TensorIteratorBase& iter, const func_t& f, const vec_func_t& vec_f) {
28 cpu_kernel_vec(iter, f, vec_f);
29 }
30 };
31
} // anonymous namespace

_validate_compressed_sparse_indices_cpu(const bool is_crow,const Tensor & cidx,const Tensor & idx,const int64_t cdim,const int64_t dim,const int64_t nnz)34 void _validate_compressed_sparse_indices_cpu(
35 const bool is_crow,
36 const Tensor& cidx,
37 const Tensor& idx,
38 const int64_t cdim,
39 const int64_t dim,
40 const int64_t nnz) {
41 // Call into
42 // compressed_index_invariance_checks_kernel<EmptyKernel, CPUVecKernel, Vectorized>
43 // to enable vectorized checks once all the conditions for that are met,
44 // see ATen/native/sparse/CompressedIndexChecksCommon.h for more details.
45 validate_compressed_sparse_indices_kernel<CPUKernel>(
46 is_crow, cidx, idx, cdim, dim, nnz);
47 }

} // namespace at::native