1 #pragma once 2 3 #include <ATen/cuda/ATenCUDAGeneral.h> 4 5 namespace at::native::sparse::cuda{ 6 7 TORCH_CUDA_CU_API void Xcoo2csr( 8 const int* coorowind, 9 int64_t nnz, 10 int64_t m, 11 int* csrrowptr); 12 13 /* Level 3 */ 14 template <typename T> 15 TORCH_CUDA_CU_API void csrmm2( 16 char transa, 17 char transb, 18 int64_t m, 19 int64_t n, 20 int64_t k, 21 int64_t nnz, 22 T alpha, 23 T* csrvala, 24 int* csrrowptra, 25 int* csrcolinda, 26 T* b, 27 int64_t ldb, 28 T beta, 29 T* c, 30 int64_t ldc); 31 32 /* format conversion */ 33 TORCH_CUDA_CU_API void CreateIdentityPermutation(int64_t nnz, int* P); 34 TORCH_CUDA_CU_API void Xcsrsort_bufferSizeExt( 35 int64_t m, 36 int64_t n, 37 int64_t nnz, 38 const int* csrRowPtr, 39 const int* csrColInd, 40 size_t* pBufferSizeInBytes); 41 TORCH_CUDA_CU_API void Xcsrsort( 42 int64_t m, 43 int64_t n, 44 int64_t nnz, 45 const int* csrRowPtr, 46 int* csrColInd, 47 int* P, 48 void* pBuffer); 49 TORCH_CUDA_CU_API void Xcoosort_bufferSizeExt( 50 int64_t m, 51 int64_t n, 52 int64_t nnz, 53 const int* cooRows, 54 const int* cooCols, 55 size_t* pBufferSizeInBytes); 56 TORCH_CUDA_CU_API void XcoosortByRow( 57 int64_t m, 58 int64_t n, 59 int64_t nnz, 60 int* cooRows, 61 int* cooCols, 62 int* P, 63 void* pBuffer); 64 } // namespace at::native::sparse::cuda 65