1 #pragma once
2
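/*
  Provides a subset of cuSPARSE functions as templates, e.g.:

    csrgeam2<scalar_t>(...)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  The functions are available in the at::cuda::sparse namespace.

  Declared in this header: csrgeam2_bufferSizeExt, csrgeam2Nnz, csrgeam2,
  bsrmm and bsrmv; additionally, when AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
  evaluates to true, bsrsv2_{bufferSize,analysis,solve} and
  bsrsm2_{bufferSize,analysis,solve}.
*/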
11
12 #include <ATen/cuda/CUDAContext.h>
13 #include <ATen/cuda/CUDASparse.h>
14
15 namespace at::cuda::sparse {
16
17 #define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
18 cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
19 const cusparseMatDescr_t descrA, int nnzA, \
20 const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
21 const int *csrSortedColIndA, const scalar_t *beta, \
22 const cusparseMatDescr_t descrB, int nnzB, \
23 const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
24 const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
25 const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
26 const int *csrSortedColIndC, size_t *pBufferSizeInBytes
27
28 template <typename scalar_t>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (scalar_t))29 inline void csrgeam2_bufferSizeExt(
30 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
31 TORCH_INTERNAL_ASSERT(
32 false,
33 "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
34 typeid(scalar_t).name());
35 }
36
37 template <>
38 void csrgeam2_bufferSizeExt<float>(
39 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
40 template <>
41 void csrgeam2_bufferSizeExt<double>(
42 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
43 template <>
44 void csrgeam2_bufferSizeExt<c10::complex<float>>(
45 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
46 template <>
47 void csrgeam2_bufferSizeExt<c10::complex<double>>(
48 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
49
50 #define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
51 cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
52 int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
53 const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
54 const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
55 int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
56
57 template <typename scalar_t>
csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES ())58 inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
59 TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
60 handle,
61 m,
62 n,
63 descrA,
64 nnzA,
65 csrSortedRowPtrA,
66 csrSortedColIndA,
67 descrB,
68 nnzB,
69 csrSortedRowPtrB,
70 csrSortedColIndB,
71 descrC,
72 csrSortedRowPtrC,
73 nnzTotalDevHostPtr,
74 workspace));
75 }
76
77 #define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
78 cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
79 const cusparseMatDescr_t descrA, int nnzA, \
80 const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
81 const int *csrSortedColIndA, const scalar_t *beta, \
82 const cusparseMatDescr_t descrB, int nnzB, \
83 const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
84 const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
85 scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
86 void *pBuffer
87
88 template <typename scalar_t>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (scalar_t))89 inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
90 TORCH_INTERNAL_ASSERT(
91 false,
92 "at::cuda::sparse::csrgeam2: not implemented for ",
93 typeid(scalar_t).name());
94 }
95
96 template <>
97 void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
98 template <>
99 void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
100 template <>
101 void csrgeam2<c10::complex<float>>(
102 CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
103 template <>
104 void csrgeam2<c10::complex<double>>(
105 CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
106
107 #define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
108 cusparseHandle_t handle, cusparseDirection_t dirA, \
109 cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
110 int kb, int nnzb, const scalar_t *alpha, \
111 const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
112 const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
113 const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
114
115 template <typename scalar_t>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (scalar_t))116 inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
117 TORCH_INTERNAL_ASSERT(
118 false,
119 "at::cuda::sparse::bsrmm: not implemented for ",
120 typeid(scalar_t).name());
121 }
122
123 template <>
124 void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
125 template <>
126 void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
127 template <>
128 void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
129 template <>
130 void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
131
132 #define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
133 cusparseHandle_t handle, cusparseDirection_t dirA, \
134 cusparseOperation_t transA, int mb, int nb, int nnzb, \
135 const scalar_t *alpha, const cusparseMatDescr_t descrA, \
136 const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
137 int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
138
139 template <typename scalar_t>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (scalar_t))140 inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
141 TORCH_INTERNAL_ASSERT(
142 false,
143 "at::cuda::sparse::bsrmv: not implemented for ",
144 typeid(scalar_t).name());
145 }
146
147 template <>
148 void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
149 template <>
150 void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
151 template <>
152 void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
153 template <>
154 void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
155
156 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
157
158 #define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
159 cusparseHandle_t handle, cusparseDirection_t dirA, \
160 cusparseOperation_t transA, int mb, int nnzb, \
161 const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
162 const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
163 bsrsv2Info_t info, int *pBufferSizeInBytes
164
165 template <typename scalar_t>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (scalar_t))166 inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
167 TORCH_INTERNAL_ASSERT(
168 false,
169 "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
170 typeid(scalar_t).name());
171 }
172
173 template <>
174 void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
175 template <>
176 void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
177 template <>
178 void bsrsv2_bufferSize<c10::complex<float>>(
179 CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
180 template <>
181 void bsrsv2_bufferSize<c10::complex<double>>(
182 CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
183
184 #define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
185 cusparseHandle_t handle, cusparseDirection_t dirA, \
186 cusparseOperation_t transA, int mb, int nnzb, \
187 const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
188 const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
189 bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
190
191 template <typename scalar_t>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (scalar_t))192 inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
193 TORCH_INTERNAL_ASSERT(
194 false,
195 "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
196 typeid(scalar_t).name());
197 }
198
199 template <>
200 void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
201 template <>
202 void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
203 template <>
204 void bsrsv2_analysis<c10::complex<float>>(
205 CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
206 template <>
207 void bsrsv2_analysis<c10::complex<double>>(
208 CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
209
210 #define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
211 cusparseHandle_t handle, cusparseDirection_t dirA, \
212 cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
213 const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
214 const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
215 bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
216 cusparseSolvePolicy_t policy, void *pBuffer
217
218 template <typename scalar_t>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (scalar_t))219 inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
220 TORCH_INTERNAL_ASSERT(
221 false,
222 "at::cuda::sparse::bsrsv2_solve: not implemented for ",
223 typeid(scalar_t).name());
224 }
225
226 template <>
227 void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
228 template <>
229 void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
230 template <>
231 void bsrsv2_solve<c10::complex<float>>(
232 CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
233 template <>
234 void bsrsv2_solve<c10::complex<double>>(
235 CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
236
237 #define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
238 cusparseHandle_t handle, cusparseDirection_t dirA, \
239 cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
240 int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
241 const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
242 bsrsm2Info_t info, int *pBufferSizeInBytes
243
244 template <typename scalar_t>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (scalar_t))245 inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
246 TORCH_INTERNAL_ASSERT(
247 false,
248 "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
249 typeid(scalar_t).name());
250 }
251
252 template <>
253 void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
254 template <>
255 void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
256 template <>
257 void bsrsm2_bufferSize<c10::complex<float>>(
258 CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
259 template <>
260 void bsrsm2_bufferSize<c10::complex<double>>(
261 CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
262
263 #define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
264 cusparseHandle_t handle, cusparseDirection_t dirA, \
265 cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
266 int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
267 const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
268 bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
269
270 template <typename scalar_t>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (scalar_t))271 inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
272 TORCH_INTERNAL_ASSERT(
273 false,
274 "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
275 typeid(scalar_t).name());
276 }
277
278 template <>
279 void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
280 template <>
281 void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
282 template <>
283 void bsrsm2_analysis<c10::complex<float>>(
284 CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
285 template <>
286 void bsrsm2_analysis<c10::complex<double>>(
287 CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
288
289 #define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
290 cusparseHandle_t handle, cusparseDirection_t dirA, \
291 cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
292 int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
293 const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
294 int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
295 scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
296
297 template <typename scalar_t>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (scalar_t))298 inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
299 TORCH_INTERNAL_ASSERT(
300 false,
301 "at::cuda::sparse::bsrsm2_solve: not implemented for ",
302 typeid(scalar_t).name());
303 }
304
305 template <>
306 void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
307 template <>
308 void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
309 template <>
310 void bsrsm2_solve<c10::complex<float>>(
311 CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
312 template <>
313 void bsrsm2_solve<c10::complex<double>>(
314 CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
315
316 #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
317
318 } // namespace at::cuda::sparse
319