#pragma once /* Provides a subset of MKL Sparse BLAS functions as templates: mv(operation, alpha, A, descr, x, beta, y) where scalar_t is double, float, c10::complex or c10::complex. The functions are available in at::mkl::sparse namespace. */ #include #include #include namespace at::mkl::sparse { #define MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t) \ sparse_matrix_t *A, const sparse_index_base_t indexing, const MKL_INT rows, \ const MKL_INT cols, MKL_INT *rows_start, MKL_INT *rows_end, \ MKL_INT *col_indx, scalar_t *values template inline void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::create_csr: not implemented for ", typeid(scalar_t).name()); } template <> void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(float)); template <> void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(double)); template <> void create_csr>( MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex)); template <> void create_csr>( MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex)); #define MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t) \ sparse_matrix_t *A, const sparse_index_base_t indexing, \ const sparse_layout_t block_layout, const MKL_INT rows, \ const MKL_INT cols, MKL_INT block_size, MKL_INT *rows_start, \ MKL_INT *rows_end, MKL_INT *col_indx, scalar_t *values template inline void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::create_bsr: not implemented for ", typeid(scalar_t).name()); } template <> void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(float)); template <> void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(double)); template <> void create_bsr>( MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex)); template <> void create_bsr>( MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex)); #define MKL_SPARSE_MV_ARGTYPES(scalar_t) \ const sparse_operation_t operation, const scalar_t alpha, \ const sparse_matrix_t A, const struct matrix_descr descr, \ const scalar_t *x, const scalar_t beta, scalar_t *y template inline void mv(MKL_SPARSE_MV_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::mv: not implemented for ", typeid(scalar_t).name()); } template <> void mv(MKL_SPARSE_MV_ARGTYPES(float)); template <> void mv(MKL_SPARSE_MV_ARGTYPES(double)); template <> void mv>(MKL_SPARSE_MV_ARGTYPES(c10::complex)); template <> void mv>(MKL_SPARSE_MV_ARGTYPES(c10::complex)); #define MKL_SPARSE_ADD_ARGTYPES(scalar_t) \ const sparse_operation_t operation, const sparse_matrix_t A, \ const scalar_t alpha, const sparse_matrix_t B, sparse_matrix_t *C template inline void add(MKL_SPARSE_ADD_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::add: not implemented for ", typeid(scalar_t).name()); } template <> void add(MKL_SPARSE_ADD_ARGTYPES(float)); template <> void add(MKL_SPARSE_ADD_ARGTYPES(double)); template <> void add>(MKL_SPARSE_ADD_ARGTYPES(c10::complex)); template <> void add>(MKL_SPARSE_ADD_ARGTYPES(c10::complex)); #define MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t) \ const sparse_matrix_t source, sparse_index_base_t *indexing, MKL_INT *rows, \ MKL_INT *cols, MKL_INT **rows_start, MKL_INT **rows_end, \ MKL_INT **col_indx, scalar_t **values template inline void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::export_csr: not implemented for ", typeid(scalar_t).name()); } template <> void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float)); template <> void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double)); template <> void export_csr>( MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex)); template <> void export_csr>( MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex)); #define MKL_SPARSE_MM_ARGTYPES(scalar_t) \ const sparse_operation_t operation, const scalar_t alpha, \ const sparse_matrix_t A, const struct matrix_descr descr, \ const sparse_layout_t layout, const scalar_t *B, const MKL_INT columns, \ const MKL_INT ldb, const scalar_t beta, scalar_t *C, const MKL_INT ldc template inline void mm(MKL_SPARSE_MM_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::mm: not implemented for ", typeid(scalar_t).name()); } template <> void mm(MKL_SPARSE_MM_ARGTYPES(float)); template <> void mm(MKL_SPARSE_MM_ARGTYPES(double)); template <> void mm>(MKL_SPARSE_MM_ARGTYPES(c10::complex)); template <> void mm>(MKL_SPARSE_MM_ARGTYPES(c10::complex)); #define MKL_SPARSE_SPMMD_ARGTYPES(scalar_t) \ const sparse_operation_t operation, const sparse_matrix_t A, \ const sparse_matrix_t B, const sparse_layout_t layout, scalar_t *C, \ const MKL_INT ldc template inline void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::spmmd: not implemented for ", typeid(scalar_t).name()); } template <> void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(float)); template <> void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(double)); template <> void spmmd>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex)); template <> void spmmd>( MKL_SPARSE_SPMMD_ARGTYPES(c10::complex)); #define MKL_SPARSE_TRSV_ARGTYPES(scalar_t) \ const sparse_operation_t operation, const scalar_t alpha, \ const sparse_matrix_t A, const struct matrix_descr descr, \ const scalar_t *x, scalar_t *y template inline sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::trsv: not implemented for ", typeid(scalar_t).name()); } template <> sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(float)); template <> sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(double)); template <> sparse_status_t trsv>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex)); template <> sparse_status_t trsv>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex)); #define MKL_SPARSE_TRSM_ARGTYPES(scalar_t) \ const sparse_operation_t operation, const scalar_t alpha, \ const sparse_matrix_t A, const struct matrix_descr descr, \ const sparse_layout_t layout, const scalar_t *x, const MKL_INT columns, \ const MKL_INT ldx, scalar_t *y, const MKL_INT ldy template inline sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(scalar_t)) { TORCH_INTERNAL_ASSERT( false, "at::mkl::sparse::trsm: not implemented for ", typeid(scalar_t).name()); } template <> sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(float)); template <> sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(double)); template <> sparse_status_t trsm>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex)); template <> sparse_status_t trsm>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex)); } // namespace at::mkl::sparse