xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/SparseBlasImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <ATen/core/Scalar.h>
5 
6 namespace at::native::sparse::impl {
7 
// Declaration only: the definition is not part of this header.
// Takes two read-only tensors and a mutable `result`, and returns a
// `Tensor&`.
// NOTE(review): judging by the names, this presumably multiplies a sparse
// compressed-row tensor (`compressed_row_sparse`) by a strided/dense tensor
// (`strided`), writes the product into `result`, and returns `result` --
// confirm against the implementation. Supported layouts, dtypes and shape
// requirements are not visible from here.
8 TORCH_API Tensor& _compressed_row_strided_mm_out(
9     const Tensor& compressed_row_sparse,
10     const Tensor& strided,
11     Tensor& result);
12 
// Declaration only: the definition is not part of this header.
// Takes three read-only tensors, two scalar factors and a mutable `result`,
// and returns a `Tensor&`.
// NOTE(review): the parameter names mirror the usual addmm convention, so
// this presumably computes `beta * self + alpha * (mat1 @ mat2)` into
// `result` and returns `result` -- confirm against the implementation.
// Which of `mat1` / `mat2` is expected to be the compressed-row sparse
// operand cannot be determined from this signature.
13 TORCH_API Tensor& _compressed_row_strided_addmm_out(
14     const Tensor& self,
15     const Tensor& mat1,
16     const Tensor& mat2,
17     const Scalar& beta,
18     const Scalar& alpha,
19     Tensor& result);
20 
21 namespace cpu {
22 
// Declaration only (nested `cpu` namespace): the definition is not part of
// this header. Returns nothing; `result` is passed by const reference.
// NOTE(review): presumably an addmv-style update combining `result` scaled
// by `beta` with `alpha * (mat @ vec)`, where `mat` is a sparse CSR matrix
// and `vec` a dense vector -- confirm against the implementation.
// NOTE(review): `result` looks like the output even though it is a
// `const Tensor&`; presumably the data is written through the tensor
// handle -- verify before relying on it being read-only.
23 void addmv_out_sparse_csr(
24     const Tensor& mat,
25     const Tensor& vec,
26     const Scalar& beta,
27     const Scalar& alpha,
28     const Tensor& result);
29 
// Declaration only (nested `cpu` namespace): the definition is not part of
// this header. Returns nothing; `result` is passed by const reference.
// NOTE(review): presumably computes `mat1 + alpha * mat2` for sparse CSR
// operands and stores the sum in `result` -- confirm against the
// implementation, including which layouts each argument may have.
30 void add_out_sparse_csr(
31     const Tensor& mat1,
32     const Tensor& mat2,
33     const Scalar& alpha,
34     const Tensor& result);
35 
// Declaration only (nested `cpu` namespace): the definition is not part of
// this header. Returns nothing; all three tensors are passed by const
// reference.
// NOTE(review): presumably solves a triangular system with sparse CSR
// coefficient matrix `A` and right-hand side `B`, storing the solution in
// `X` -- confirm against the implementation.
// NOTE(review): flag meanings are inferred from their names only:
//   upper         - treat `A` as upper (true) or lower (false) triangular
//   transpose     - solve with the transpose of `A`
//   unitriangular - assume the diagonal of `A` is all ones
// TODO confirm each of these against the definition.
36 void triangular_solve_out_sparse_csr(
37     const Tensor& A,
38     const Tensor& B,
39     const Tensor& X,
40     bool upper,
41     bool transpose,
42     bool unitriangular);
43 
44 } // namespace cpu
45 } // namespace at::native::sparse::impl
46