// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2 Functions here use deprecated cuSPARSE API that was removed in CUDA 11.
3 This file will be removed eventually.
4 */
5 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
6 #include <ATen/core/Tensor.h>
7 #include <ATen/Dispatch.h>
8 #include <ATen/native/SparseTensorUtils.h>
9 #include <ATen/native/sparse/cuda/SparseBlasLegacy.h>
10 #include <ATen/native/sparse/cuda/SparseCUDABlas.h>
11 
12 namespace at::native {
13 
s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz,int64_t m,int64_t n,int64_t k,const Tensor & r_,const Scalar & beta,const Tensor & t,const Scalar & alpha,const Tensor & crow_indices,const Tensor & col_indices,const Tensor & values,const Tensor & dense)14 void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, const Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, const Tensor& dense) {
15   TORCH_INTERNAL_ASSERT(nnz > 0);
16 
17   // No half support, so we don't have to use CUDATypeConversion
18   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
19       values.scalar_type(), "addmm_sparse_cuda", [&] {
20         scalar_t cast_beta = beta.to<scalar_t>();
21         scalar_t cast_alpha = alpha.to<scalar_t>();
22         Tensor r__;
23         if (cast_beta == scalar_t(0)) {
24           r_.zero_();
25         } else if (!at::sparse::is_same_tensor(t, r_)) {
26           r_.copy_(t);
27         }
28         if (r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) {
29           r__ = r_;
30         } else {
31           // Note: This storage arrangement is preferred due to most of the CUDA kernels handle only contiguous tensors
32           r__ = r_.transpose(0, 1).clone(at::MemoryFormat::Contiguous);
33           r__.transpose_(0, 1);
34         }
35         TORCH_INTERNAL_ASSERT(r__.mT().is_contiguous());
36         Tensor dense_;
37         char transpose_dense;
38         if (dense.stride(0) == 1 && dense.stride(1) == dense.size(0)) {
39           transpose_dense = 'n';
40           dense_ = dense;
41         } else if (dense.stride(1) == 1 && dense.stride(0) == dense.size(1)) {
42           transpose_dense = 't';
43           dense_ = dense;
44         } else {
45           transpose_dense = 't';
46           dense_ = dense.contiguous();
47         }
48 
49         sparse::cuda::csrmm2(
50           'n',
51           transpose_dense,
52           m,
53           n,
54           k,
55           nnz,
56           cast_alpha,
57           values.data_ptr<scalar_t>(),
58           crow_indices.data_ptr<int32_t>(),
59           col_indices.data_ptr<int32_t>(),
60           dense_.data_ptr<scalar_t>(),
61           (transpose_dense == 'n' ? dense_.stride(1) : dense_.stride(0)),
62           cast_beta,
63           r__.data_ptr<scalar_t>(),
64           r__.stride(1));
65 
66         if (!at::sparse::is_same_tensor(r__, r_)) {
67           r_.copy_(r__);
68         }
69       }
70     );
71 }
72 
73 } // namespace at::native
74