// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDASparseBlas.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
/*
  Provides a subset of cuSPARSE functions as templates, for example:

    csrgeam2<scalar_t>(...)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  The functions are available in the at::cuda::sparse namespace.
*/
11 
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDASparse.h>

#include <typeinfo>
14 
15 namespace at::cuda::sparse {
16 
17 #define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)             \
18   cusparseHandle_t handle, int m, int n, const scalar_t *alpha,     \
19       const cusparseMatDescr_t descrA, int nnzA,                    \
20       const scalar_t *csrSortedValA, const int *csrSortedRowPtrA,   \
21       const int *csrSortedColIndA, const scalar_t *beta,            \
22       const cusparseMatDescr_t descrB, int nnzB,                    \
23       const scalar_t *csrSortedValB, const int *csrSortedRowPtrB,   \
24       const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
25       const scalar_t *csrSortedValC, const int *csrSortedRowPtrC,   \
26       const int *csrSortedColIndC, size_t *pBufferSizeInBytes
27 
28 template <typename scalar_t>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (scalar_t))29 inline void csrgeam2_bufferSizeExt(
30     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
31   TORCH_INTERNAL_ASSERT(
32       false,
33       "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
34       typeid(scalar_t).name());
35 }
36 
37 template <>
38 void csrgeam2_bufferSizeExt<float>(
39     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
40 template <>
41 void csrgeam2_bufferSizeExt<double>(
42     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
43 template <>
44 void csrgeam2_bufferSizeExt<c10::complex<float>>(
45     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
46 template <>
47 void csrgeam2_bufferSizeExt<c10::complex<double>>(
48     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
49 
50 #define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()                                      \
51   cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA,     \
52       int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA,     \
53       const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
54       const int *csrSortedColIndB, const cusparseMatDescr_t descrC,           \
55       int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
56 
57 template <typename scalar_t>
csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES ())58 inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
59   TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
60       handle,
61       m,
62       n,
63       descrA,
64       nnzA,
65       csrSortedRowPtrA,
66       csrSortedColIndA,
67       descrB,
68       nnzB,
69       csrSortedRowPtrB,
70       csrSortedColIndB,
71       descrC,
72       csrSortedRowPtrC,
73       nnzTotalDevHostPtr,
74       workspace));
75 }
76 
77 #define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)                                 \
78   cusparseHandle_t handle, int m, int n, const scalar_t *alpha,              \
79       const cusparseMatDescr_t descrA, int nnzA,                             \
80       const scalar_t *csrSortedValA, const int *csrSortedRowPtrA,            \
81       const int *csrSortedColIndA, const scalar_t *beta,                     \
82       const cusparseMatDescr_t descrB, int nnzB,                             \
83       const scalar_t *csrSortedValB, const int *csrSortedRowPtrB,            \
84       const int *csrSortedColIndB, const cusparseMatDescr_t descrC,          \
85       scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
86       void *pBuffer
87 
88 template <typename scalar_t>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (scalar_t))89 inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
90   TORCH_INTERNAL_ASSERT(
91       false,
92       "at::cuda::sparse::csrgeam2: not implemented for ",
93       typeid(scalar_t).name());
94 }
95 
96 template <>
97 void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
98 template <>
99 void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
100 template <>
101 void csrgeam2<c10::complex<float>>(
102     CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
103 template <>
104 void csrgeam2<c10::complex<double>>(
105     CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
106 
107 #define CUSPARSE_BSRMM_ARGTYPES(scalar_t)                                    \
108   cusparseHandle_t handle, cusparseDirection_t dirA,                         \
109       cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
110       int kb, int nnzb, const scalar_t *alpha,                               \
111       const cusparseMatDescr_t descrA, const scalar_t *bsrValA,              \
112       const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,            \
113       const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
114 
115 template <typename scalar_t>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (scalar_t))116 inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
117   TORCH_INTERNAL_ASSERT(
118       false,
119       "at::cuda::sparse::bsrmm: not implemented for ",
120       typeid(scalar_t).name());
121 }
122 
123 template <>
124 void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
125 template <>
126 void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
127 template <>
128 void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
129 template <>
130 void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
131 
132 #define CUSPARSE_BSRMV_ARGTYPES(scalar_t)                                    \
133   cusparseHandle_t handle, cusparseDirection_t dirA,                         \
134       cusparseOperation_t transA, int mb, int nb, int nnzb,                  \
135       const scalar_t *alpha, const cusparseMatDescr_t descrA,                \
136       const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
137       int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
138 
139 template <typename scalar_t>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (scalar_t))140 inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
141   TORCH_INTERNAL_ASSERT(
142       false,
143       "at::cuda::sparse::bsrmv: not implemented for ",
144       typeid(scalar_t).name());
145 }
146 
147 template <>
148 void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
149 template <>
150 void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
151 template <>
152 void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
153 template <>
154 void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
155 
156 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
157 
158 #define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)                 \
159   cusparseHandle_t handle, cusparseDirection_t dirA,              \
160       cusparseOperation_t transA, int mb, int nnzb,               \
161       const cusparseMatDescr_t descrA, scalar_t *bsrValA,         \
162       const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
163       bsrsv2Info_t info, int *pBufferSizeInBytes
164 
165 template <typename scalar_t>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (scalar_t))166 inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
167   TORCH_INTERNAL_ASSERT(
168       false,
169       "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
170       typeid(scalar_t).name());
171 }
172 
173 template <>
174 void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
175 template <>
176 void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
177 template <>
178 void bsrsv2_bufferSize<c10::complex<float>>(
179     CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
180 template <>
181 void bsrsv2_bufferSize<c10::complex<double>>(
182     CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
183 
184 #define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)               \
185   cusparseHandle_t handle, cusparseDirection_t dirA,              \
186       cusparseOperation_t transA, int mb, int nnzb,               \
187       const cusparseMatDescr_t descrA, const scalar_t *bsrValA,   \
188       const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
189       bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
190 
191 template <typename scalar_t>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (scalar_t))192 inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
193   TORCH_INTERNAL_ASSERT(
194       false,
195       "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
196       typeid(scalar_t).name());
197 }
198 
199 template <>
200 void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
201 template <>
202 void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
203 template <>
204 void bsrsv2_analysis<c10::complex<float>>(
205     CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
206 template <>
207 void bsrsv2_analysis<c10::complex<double>>(
208     CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
209 
210 #define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)                           \
211   cusparseHandle_t handle, cusparseDirection_t dirA,                       \
212       cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
213       const cusparseMatDescr_t descrA, const scalar_t *bsrValA,            \
214       const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,          \
215       bsrsv2Info_t info, const scalar_t *x, scalar_t *y,                   \
216       cusparseSolvePolicy_t policy, void *pBuffer
217 
218 template <typename scalar_t>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (scalar_t))219 inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
220   TORCH_INTERNAL_ASSERT(
221       false,
222       "at::cuda::sparse::bsrsv2_solve: not implemented for ",
223       typeid(scalar_t).name());
224 }
225 
226 template <>
227 void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
228 template <>
229 void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
230 template <>
231 void bsrsv2_solve<c10::complex<float>>(
232     CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
233 template <>
234 void bsrsv2_solve<c10::complex<double>>(
235     CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
236 
237 #define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)                            \
238   cusparseHandle_t handle, cusparseDirection_t dirA,                         \
239       cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
240       int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA,          \
241       const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,            \
242       bsrsm2Info_t info, int *pBufferSizeInBytes
243 
244 template <typename scalar_t>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (scalar_t))245 inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
246   TORCH_INTERNAL_ASSERT(
247       false,
248       "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
249       typeid(scalar_t).name());
250 }
251 
252 template <>
253 void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
254 template <>
255 void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
256 template <>
257 void bsrsm2_bufferSize<c10::complex<float>>(
258     CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
259 template <>
260 void bsrsm2_bufferSize<c10::complex<double>>(
261     CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
262 
263 #define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)                          \
264   cusparseHandle_t handle, cusparseDirection_t dirA,                         \
265       cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
266       int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA,    \
267       const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,            \
268       bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
269 
270 template <typename scalar_t>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (scalar_t))271 inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
272   TORCH_INTERNAL_ASSERT(
273       false,
274       "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
275       typeid(scalar_t).name());
276 }
277 
278 template <>
279 void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
280 template <>
281 void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
282 template <>
283 void bsrsm2_analysis<c10::complex<float>>(
284     CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
285 template <>
286 void bsrsm2_analysis<c10::complex<double>>(
287     CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
288 
289 #define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)                             \
290   cusparseHandle_t handle, cusparseDirection_t dirA,                         \
291       cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
292       int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA,      \
293       const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
294       int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb,           \
295       scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
296 
297 template <typename scalar_t>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (scalar_t))298 inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
299   TORCH_INTERNAL_ASSERT(
300       false,
301       "at::cuda::sparse::bsrsm2_solve: not implemented for ",
302       typeid(scalar_t).name());
303 }
304 
305 template <>
306 void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
307 template <>
308 void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
309 template <>
310 void bsrsm2_solve<c10::complex<float>>(
311     CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
312 template <>
313 void bsrsm2_solve<c10::complex<double>>(
314     CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
315 
316 #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
317 
318 } // namespace at::cuda::sparse
319