/*
Functions here use deprecated cuSPARSE API that was removed in CUDA 11.
This file will be removed eventually.
*/
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/sparse/cuda/SparseBlasLegacy.h>
#include <ATen/native/sparse/cuda/SparseCUDABlas.h>

namespace at::native {

s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz,int64_t m,int64_t n,int64_t k,const Tensor & r_,const Scalar & beta,const Tensor & t,const Scalar & alpha,const Tensor & crow_indices,const Tensor & col_indices,const Tensor & values,const Tensor & dense)14 void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, const Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, const Tensor& dense) {
15 TORCH_INTERNAL_ASSERT(nnz > 0);
16
17 // No half support, so we don't have to use CUDATypeConversion
18 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
19 values.scalar_type(), "addmm_sparse_cuda", [&] {
20 scalar_t cast_beta = beta.to<scalar_t>();
21 scalar_t cast_alpha = alpha.to<scalar_t>();
22 Tensor r__;
23 if (cast_beta == scalar_t(0)) {
24 r_.zero_();
25 } else if (!at::sparse::is_same_tensor(t, r_)) {
26 r_.copy_(t);
27 }
28 if (r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) {
29 r__ = r_;
30 } else {
31 // Note: This storage arrangement is preferred due to most of the CUDA kernels handle only contiguous tensors
32 r__ = r_.transpose(0, 1).clone(at::MemoryFormat::Contiguous);
33 r__.transpose_(0, 1);
34 }
35 TORCH_INTERNAL_ASSERT(r__.mT().is_contiguous());
36 Tensor dense_;
37 char transpose_dense;
38 if (dense.stride(0) == 1 && dense.stride(1) == dense.size(0)) {
39 transpose_dense = 'n';
40 dense_ = dense;
41 } else if (dense.stride(1) == 1 && dense.stride(0) == dense.size(1)) {
42 transpose_dense = 't';
43 dense_ = dense;
44 } else {
45 transpose_dense = 't';
46 dense_ = dense.contiguous();
47 }
48
49 sparse::cuda::csrmm2(
50 'n',
51 transpose_dense,
52 m,
53 n,
54 k,
55 nnz,
56 cast_alpha,
57 values.data_ptr<scalar_t>(),
58 crow_indices.data_ptr<int32_t>(),
59 col_indices.data_ptr<int32_t>(),
60 dense_.data_ptr<scalar_t>(),
61 (transpose_dense == 'n' ? dense_.stride(1) : dense_.stride(0)),
62 cast_beta,
63 r__.data_ptr<scalar_t>(),
64 r__.stride(1));
65
66 if (!at::sparse::is_same_tensor(r__, r_)) {
67 r_.copy_(r__);
68 }
69 }
70 );
71 }

} // namespace at::native
