xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cuda/ATenCUDAGeneral.h>
4 
5 namespace at::native::sparse::cuda{
6 
7 TORCH_CUDA_CU_API void Xcoo2csr(
8     const int* coorowind,
9     int64_t nnz,
10     int64_t m,
11     int* csrrowptr);
12 
13 /* Level 3 */
14 template <typename T>
15 TORCH_CUDA_CU_API void csrmm2(
16     char transa,
17     char transb,
18     int64_t m,
19     int64_t n,
20     int64_t k,
21     int64_t nnz,
22     T alpha,
23     T* csrvala,
24     int* csrrowptra,
25     int* csrcolinda,
26     T* b,
27     int64_t ldb,
28     T beta,
29     T* c,
30     int64_t ldc);
31 
32 /* format conversion */
33 TORCH_CUDA_CU_API void CreateIdentityPermutation(int64_t nnz, int* P);
34 TORCH_CUDA_CU_API void Xcsrsort_bufferSizeExt(
35     int64_t m,
36     int64_t n,
37     int64_t nnz,
38     const int* csrRowPtr,
39     const int* csrColInd,
40     size_t* pBufferSizeInBytes);
41 TORCH_CUDA_CU_API void Xcsrsort(
42     int64_t m,
43     int64_t n,
44     int64_t nnz,
45     const int* csrRowPtr,
46     int* csrColInd,
47     int* P,
48     void* pBuffer);
49 TORCH_CUDA_CU_API void Xcoosort_bufferSizeExt(
50     int64_t m,
51     int64_t n,
52     int64_t nnz,
53     const int* cooRows,
54     const int* cooCols,
55     size_t* pBufferSizeInBytes);
56 TORCH_CUDA_CU_API void XcoosortByRow(
57     int64_t m,
58     int64_t n,
59     int64_t nnz,
60     int* cooRows,
61     int* cooCols,
62     int* P,
63     void* pBuffer);
64 } // namespace at::native::sparse::cuda
65