#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/sparse/SparseBlas.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/native/cpu/SampledAddmmKernel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/sparse_sampled_addmm_native.h>
#include <ATen/ops/triangular_solve_native.h>
#endif

#include <c10/util/MaybeOwned.h>

namespace at::native {
Tensor& addmv_out_sparse_compressed(
    const Tensor& self,
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  TORCH_CHECK(
      mat.layout() != kSparseBsc,
      "torch.addmv: operation not supported for mat with SparseBsc layout");
  if (mat.layout() == kSparseCsc) {
    // TODO: Add native CSC support to avoid this expensive conversion
    return addmv_out_sparse_compressed(
        self, mat.to_sparse_csr(), vec, beta, alpha, result);
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      mat.layout() == kSparseCsr || mat.layout() == kSparseBsr);

  TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D");
  TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D");

  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta.toComplexDouble();

  if (&result != &self) {
    at::native::resize_output(result, self_->sizes());
    if (betaval != 0.0) {
      at::native::copy_(result, *self_);
    }
  }

  if (mat._nnz() == 0) {
    // shortcut for an empty matrix
    // By definition, when beta==0, values in self should be ignored. nans and
    // infs should not propagate
    if (betaval == 0.0) {
      return result.zero_();
    } else {
      return at::mul_out(
          const_cast<Tensor&>(result),
          self,
          at::native::scalar_tensor(
              beta,
              self.scalar_type(),
              std::nullopt /*layout*/,
              at::kCPU,
              std::nullopt /* pin_memory */));
    }
  }

  sparse::impl::cpu::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
  return result;
}
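
/*
  Usage note (editorial, hedged): a minimal sketch of how this kernel is
  typically reached from client code, assuming the public ATen factory
  at::sparse_csr_tensor and the at::addmv op; the shapes and values below are
  illustrative only and not part of this file.

    // 2x2 CSR matrix with one nonzero per row.
    auto crow = at::tensor({0, 1, 2}, at::kLong);
    auto col  = at::tensor({0, 1}, at::kLong);
    auto vals = at::tensor({1.0, 2.0}, at::kFloat);
    auto mat  = at::sparse_csr_tensor(crow, col, vals, {2, 2}, vals.options());
    auto vec  = at::ones({2});
    auto self = at::zeros({2});
    // Computes beta * self + alpha * (mat @ vec); CSR/BSR inputs dispatch here,
    // and CSC inputs are first converted to CSR as done above.
    auto out = at::addmv(self, mat, vec, /*beta=*/1.0, /*alpha=*/1.0);
*/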

/*
  Solves a system of linear equations whose coefficients are represented in a sparse triangular matrix A:
  op(A) X = B.

  Args:
  * `B` - dense Tensor of size m × nrhs.
  * `A` - sparse Tensor of size m × m.
  * `upper` - controls whether the upper or lower triangular part of A is considered in the computation.
  * `transpose` - if true then op(A) = A^T.
  * `unitriangular` - if true then the diagonal elements of A are assumed to be one.
  * `X` - dense Tensor of size m × nrhs.
  * `clone_A` - cloned matrix A, required only for compatibility with the strided layout interface.
*/
std::tuple<Tensor&, Tensor&> triangular_solve_out_sparse_csr_cpu(
    const Tensor& B,
    const Tensor& A,
    bool upper,
    bool transpose,
    bool unitriangular,
    Tensor& X,
    Tensor& clone_A) {
  sparse::impl::cpu::triangular_solve_out_sparse_csr(A, B, X, upper, transpose, unitriangular);
  return std::tuple<Tensor&, Tensor&>(X, clone_A);
}
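
/*
  Usage note (editorial, hedged): a rough sketch of reaching this kernel
  through the public at::triangular_solve op with a sparse CSR A; the tensors
  A and B below are placeholders, not part of this file.

    // A: m x m sparse CSR triangular matrix, B: m x nrhs dense right-hand side.
    // Returns (X, cloned_A); for sparse A the clone is only kept for interface
    // compatibility with the strided path.
    auto [X, A_clone] = at::triangular_solve(
        B, A, /*upper=*/true, /*transpose=*/false, /*unitriangular=*/false);
*/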

/*
  Computes `result` <- α*(A @ B) * spy(C) + β*C, where spy(C) is the sparsity pattern matrix of C.

  Args:
  * `mat1` - [in] dense Tensor A of size m × k.
  * `mat2` - [in] dense Tensor B of size k × n.
  * `self` - [in] sparse Tensor C of size m × n.
  * `result` - [out] sparse Tensor of size m × n.
*/
Tensor& sparse_sampled_addmm_out_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  at::native::sparse::sparse_sampled_addmm_check_inputs(self, mat1, mat2, beta, alpha, result);
  // Allow only the same types as for the CUDA path
  auto t = self.scalar_type();
  TORCH_CHECK(t == ScalarType::Double || t == ScalarType::Float ||
    t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble,
    "sparse_sampled_addmm: Expected self to be a floating-point or complex tensor, but got ", t);
  if (&result != &self) {
    // We allow self to be a single matrix when mat1 and mat2 are batched
    auto result_sizes = DimVector(mat1.sizes().slice(0, mat1.dim() - 2));
    result_sizes.push_back(self.size(-2));
    result_sizes.push_back(self.size(-1));
    at::sparse_csr::get_sparse_csr_impl(result)->resize_(self._nnz(), result_sizes);
    result.copy_(self);
  }

  if (mat1.numel() == 0 || mat2.numel() == 0 || result._nnz() == 0) {
    result.mul_(beta);
    return result;
  }

  // Transpose mat2 to [b, n, k] for performance reasons.
  // For the classic GNN use case, mat2 is already stored as [b, n, k] physically,
  // so no extra memcpy is needed.
  auto mat2_t = mat2.transpose(-1, -2).contiguous();
  sampled_addmm_sparse_csr_stub(kCPU, mat1.contiguous(), mat2_t, beta, alpha, result);

  return result;
}
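
/*
  Usage note (editorial, hedged): a minimal sketch of the functional entry
  point for this kernel; in Python it is exposed as torch.sparse.sampled_addmm,
  and from C++ the dispatcher op at::sparse_sampled_addmm can be used. The
  tensors C, A, B below are placeholders.

    // C: m x n sparse CSR, A: m x k dense, B: k x n dense.
    // Computes alpha * (A @ B) masked by the sparsity pattern of C, plus beta * C.
    auto out = at::sparse_sampled_addmm(C, A, B, /*beta=*/1.0, /*alpha=*/1.0);
*/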

Tensor sparse_sampled_addmm_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha) {
  auto result = at::empty({0, 0}, self.options());
  at::native::sparse_sampled_addmm_out_sparse_csr_cpu(self, mat1, mat2, beta, alpha, result);
  return result;
}

namespace sparse {

void sparse_sampled_addmm_check_inputs(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.is_sparse_csr());

  TORCH_CHECK(
      mat1.layout() == kStrided,
      "sampled_addmm: Expected mat1 to have strided layout, but got ",
      mat1.layout());
  TORCH_CHECK(
      mat2.layout() == kStrided,
      "sampled_addmm: Expected mat2 to have strided layout, but got ",
      mat2.layout());

  TORCH_CHECK(
      result.layout() == kSparseCsr,
      "sampled_addmm: Expected result to have sparse csr layout, but got ",
      result.layout());
  TORCH_CHECK(self.dense_dim() == 0,
      "sampled_addmm: Expected non-hybrid self tensor");
  TORCH_CHECK(result.dense_dim() == 0,
      "sampled_addmm: Expected non-hybrid result tensor");

  TORCH_CHECK(
      mat1.scalar_type() == mat2.scalar_type(),
      "sampled_addmm: Expected mat1 and mat2 to have the same dtype, but got ",
      mat1.scalar_type(),
      " and ",
      mat2.scalar_type());
  TORCH_CHECK(
      mat1.scalar_type() == self.scalar_type(),
      "sampled_addmm: Expected mat1 and self to have the same dtype, but got ",
      mat1.scalar_type(),
      " and ",
      self.scalar_type());
  TORCH_CHECK(
      result.scalar_type() == self.scalar_type(),
      "sampled_addmm: Expected result and self to have the same dtype, but got ",
      result.scalar_type(),
      " and ",
      self.scalar_type());

  TORCH_CHECK(
      mat1.dim() >= 2,
      "sampled_addmm: Expected mat1 to be a matrix, got ",
      mat1.dim(),
      "-D tensor");
  TORCH_CHECK(
      mat2.dim() >= 2,
      "sampled_addmm: Expected mat2 to be a matrix, got ",
      mat2.dim(),
      "-D tensor");
  TORCH_CHECK(
      result.dim() >= 2,
      "sampled_addmm: Expected result to be a matrix, got ",
      result.dim(),
      "-D tensor");

  TORCH_CHECK(
    mat1.sizes().slice(0, mat1.dim() - 2) == mat2.sizes().slice(0, mat2.dim() - 2),
    "sampled_addmm: Expected mat1 and mat2 to have the same batch size, but got ",
    mat1.sizes().slice(0, mat1.dim() - 2),
    " and ",
    mat2.sizes().slice(0, mat2.dim() - 2));

  TORCH_CHECK(
    !(self.dim() > 2 && self.sizes().slice(0, self.dim() - 2) != mat1.sizes().slice(0, mat1.dim() - 2)),
    "sampled_addmm: Expected self and mat1 to have the same batch size, but got ",
    self.sizes().slice(0, self.dim() - 2),
    " and ",
    mat1.sizes().slice(0, mat1.dim() - 2));

  IntArrayRef mat1_sizes = mat1.sizes();
  IntArrayRef mat2_sizes = mat2.sizes();
  TORCH_CHECK(
      mat1_sizes[mat1.dim() - 1] == mat2_sizes[mat2.dim() - 2],
      "sampled_addmm: mat1 and mat2 shapes cannot be multiplied (",
      mat1_sizes[mat1.dim() - 2],
      "x",
      mat1_sizes[mat1.dim() - 1],
      " and ",
      mat2_sizes[mat2.dim() - 2],
      "x",
      mat2_sizes[mat2.dim() - 1],
      ")");

  IntArrayRef self_sizes = self.sizes();
  TORCH_CHECK(
      self_sizes[self.dim() - 2] == mat1_sizes[mat1.dim() - 2],
      "sampled_addmm: self.shape[-2] must match mat1.shape[-2]");
  TORCH_CHECK(
      self_sizes[self.dim() - 1] == mat2_sizes[mat2.dim() - 1],
      "sampled_addmm: self.shape[-1] must match mat2.shape[-1]");
}
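
/*
  Shape summary (editorial, hedged): the checks above accept, for example,
  mat1: [b, m, k] and mat2: [b, k, n], with self either [m, n] sparse CSR
  (a single sampling pattern shared across the batch, as noted in the caller)
  or batched with matching batch dims; self, mat1, mat2, and result must all
  share one dtype, and self/result must be non-hybrid sparse CSR tensors.
*/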

} // namespace sparse

DEFINE_DISPATCH(sampled_addmm_sparse_csr_stub);

} // namespace at::native