#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/sparse/SparseBlas.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/native/cpu/SampledAddmmKernel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/sparse_sampled_addmm_native.h>
#include <ATen/ops/triangular_solve_native.h>
#endif

#include <c10/util/MaybeOwned.h>

namespace at::native {

Tensor& addmv_out_sparse_compressed(
    const Tensor& self,
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  TORCH_CHECK(
      mat.layout() != kSparseBsc,
      "torch.addmv: operation not supported for mat with SparseBsc layout");
  if (mat.layout() == kSparseCsc) {
    // TODO: Add native CSC support to avoid this expensive conversion
    return addmv_out_sparse_compressed(
        self, mat.to_sparse_csr(), vec, beta, alpha, result);
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      mat.layout() == kSparseCsr || mat.layout() == kSparseBsr);

  TORCH_CHECK(mat.dim() == 2, "addmv: Expected mat to be 2-D");
  TORCH_CHECK(vec.dim() == 1, "addmv: Expected vec to be 1-D");

  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta.toComplexDouble();

  if (&result != &self) {
    at::native::resize_output(result, self_->sizes());
    if (betaval != 0.0) {
      at::native::copy_(result, *self_);
    }
  }

  if (mat._nnz() == 0) {
    // Shortcut for an empty matrix.
    // By definition, when beta == 0, values in self should be ignored. NaNs
    // and infs should not propagate.
    if (betaval == 0.0) {
      return result.zero_();
    } else {
      return at::mul_out(
          const_cast<Tensor&>(result),
          self,
          at::native::scalar_tensor(
              beta,
              self.scalar_type(),
              std::nullopt /*layout*/,
              at::kCPU,
              std::nullopt /* pin_memory */));
    }
  }

  sparse::impl::cpu::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
  return result;
}
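
/*
  Usage sketch (illustrative only, not part of this file's API surface): this
  kernel backs at::addmv when `mat` carries a sparse compressed layout.
  Assuming a CSR matrix obtained via to_sparse_csr(), a call would look
  roughly like:

    // y = beta * y0 + alpha * (A @ x), with A sparse CSR and x, y0 dense
    auto A = dense_A.to_sparse_csr();   // dense_A: 2-D strided tensor
    auto y = at::addmv(y0, A, x, beta, alpha);

  CSC inputs are routed through the CSR path above via to_sparse_csr(), and
  BSC inputs are rejected up front.
*/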

/*
  Solves a system of linear equations whose coefficients are represented in a sparse triangular matrix A:
  op(A) X = B.

  Args:
  * `B` - dense Tensor of size m × nrhs.
  * `A` - sparse Tensor of size m × m.
  * `upper` - controls whether the upper or the lower triangular part of A is considered in the computation.
  * `transpose` - if true then op(A) = A^T.
  * `unitriangular` - if true then the diagonal elements of A are assumed to be one.
  * `X` - dense Tensor of size m × nrhs.
  * `clone_A` - cloned matrix A, required only for compatibility with the strided layout interface.
*/
std::tuple<Tensor&, Tensor&> triangular_solve_out_sparse_csr_cpu(
    const Tensor& B,
    const Tensor& A,
    bool upper,
    bool transpose,
    bool unitriangular,
    Tensor& X,
    Tensor& clone_A) {
  sparse::impl::cpu::triangular_solve_out_sparse_csr(A, B, X, upper, transpose, unitriangular);
  return std::tuple<Tensor&, Tensor&>(X, clone_A);
}
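
/*
  Usage sketch (illustrative, assuming the public triangular_solve op routes a
  sparse CSR coefficient matrix to this kernel):

    auto A_csr = A_dense.tril().to_sparse_csr();   // lower-triangular CSR A
    auto [X, cloned] = at::triangular_solve(
        B, A_csr, /*upper=*/false, /*transpose=*/false, /*unitriangular=*/false);

  The second returned tensor is only meaningful for the strided path; here
  `clone_A` is passed through untouched.
*/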

/*
  Computes `result` <- α*(A @ B) * spy(C) + β*C, where spy(C) is the sparsity pattern matrix of C.

  Args:
  * `mat1` - [in] dense Tensor A of size m × k.
  * `mat2` - [in] dense Tensor B of size k × n.
  * `self` - [in] sparse Tensor C of size m × n.
  * `result` - [out] sparse Tensor of size m × n.
*/
Tensor& sparse_sampled_addmm_out_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  at::native::sparse::sparse_sampled_addmm_check_inputs(self, mat1, mat2, beta, alpha, result);
  // Allow only the same types as for the CUDA path
  auto t = self.scalar_type();
  TORCH_CHECK(t == ScalarType::Double || t == ScalarType::Float ||
      t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble,
      "sparse_sampled_addmm: Expected self to be a floating-point or complex tensor, but got ", t);
  if (&result != &self) {
    // We allow self to be a single matrix when mat1 and mat2 are batched
    auto result_sizes = DimVector(mat1.sizes().slice(0, mat1.dim() - 2));
    result_sizes.push_back(self.size(-2));
    result_sizes.push_back(self.size(-1));
    at::sparse_csr::get_sparse_csr_impl(result)->resize_(self._nnz(), result_sizes);
    result.copy_(self);
  }

  if (mat1.numel() == 0 || mat2.numel() == 0 || result._nnz() == 0) {
    result.mul_(beta);
    return result;
  }

  // Transpose mat2 to [b, n, k] for performance: for the classic GNN use case,
  // mat2 is already stored as [b, n, k] physically, so no extra memcpy is needed.
  auto mat2_t = mat2.transpose(-1, -2).contiguous();
  sampled_addmm_sparse_csr_stub(kCPU, mat1.contiguous(), mat2_t, beta, alpha, result);

  return result;
}

Tensor sparse_sampled_addmm_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha) {
  auto result = at::empty({0, 0}, self.options());
  at::native::sparse_sampled_addmm_out_sparse_csr_cpu(self, mat1, mat2, beta, alpha, result);
  return result;
}
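
/*
  Usage sketch (illustrative, assuming the functional op above is the typical
  entry point for this path): given dense mat1 (m × k), dense mat2 (k × n),
  and a sparse CSR tensor C (m × n) supplying the sampling pattern,

    auto C = mask_dense.to_sparse_csr();
    auto out = at::sparse_sampled_addmm(C, mat1, mat2, /*beta=*/1, /*alpha=*/1);

  `out` is sparse CSR with the same sparsity pattern as C; only the sampled
  entries of mat1 @ mat2 are computed.
*/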

namespace sparse {

void sparse_sampled_addmm_check_inputs(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.is_sparse_csr());

  TORCH_CHECK(
      mat1.layout() == kStrided,
      "sampled_addmm: Expected mat1 to have strided layout, but got ",
      mat1.layout());
  TORCH_CHECK(
      mat2.layout() == kStrided,
      "sampled_addmm: Expected mat2 to have strided layout, but got ",
      mat2.layout());

  TORCH_CHECK(
      result.layout() == kSparseCsr,
      "sampled_addmm: Expected result to have sparse csr layout, but got ",
      result.layout());
  TORCH_CHECK(self.dense_dim() == 0,
      "sampled_addmm: Expected non-hybrid self tensor");
  TORCH_CHECK(result.dense_dim() == 0,
      "sampled_addmm: Expected non-hybrid result tensor");

  TORCH_CHECK(
      mat1.scalar_type() == mat2.scalar_type(),
      "sampled_addmm: Expected mat1 and mat2 to have the same dtype, but got ",
      mat1.scalar_type(),
      " and ",
      mat2.scalar_type());
  TORCH_CHECK(
      mat1.scalar_type() == self.scalar_type(),
      "sampled_addmm: Expected mat1 and self to have the same dtype, but got ",
      mat1.scalar_type(),
      " and ",
      self.scalar_type());
  TORCH_CHECK(
      result.scalar_type() == self.scalar_type(),
      "sampled_addmm: Expected result and self to have the same dtype, but got ",
      result.scalar_type(),
      " and ",
      self.scalar_type());

  TORCH_CHECK(
      mat1.dim() >= 2,
      "sampled_addmm: Expected mat1 to be a matrix, got ",
      mat1.dim(),
      "-D tensor");
  TORCH_CHECK(
      mat2.dim() >= 2,
      "sampled_addmm: Expected mat2 to be a matrix, got ",
      mat2.dim(),
      "-D tensor");
  TORCH_CHECK(
      result.dim() >= 2,
      "sampled_addmm: Expected result to be a matrix, got ",
      result.dim(),
      "-D tensor");

  TORCH_CHECK(
      mat1.sizes().slice(0, mat1.dim() - 2) == mat2.sizes().slice(0, mat2.dim() - 2),
      "sampled_addmm: Expected mat1 and mat2 to have the same batch size, but got ",
      mat1.sizes().slice(0, mat1.dim() - 2),
      " and ",
      mat2.sizes().slice(0, mat2.dim() - 2));

  TORCH_CHECK(
      !(self.dim() > 2 && self.sizes().slice(0, self.dim() - 2) != mat1.sizes().slice(0, mat1.dim() - 2)),
      "sampled_addmm: Expected self and mat1 to have the same batch size, but got ",
      self.sizes().slice(0, self.dim() - 2),
      " and ",
      mat1.sizes().slice(0, mat1.dim() - 2));

  IntArrayRef mat1_sizes = mat1.sizes();
  IntArrayRef mat2_sizes = mat2.sizes();
  TORCH_CHECK(
      mat1_sizes[mat1.dim() - 1] == mat2_sizes[mat2.dim() - 2],
      "sampled_addmm: mat1 and mat2 shapes cannot be multiplied (",
      mat1_sizes[mat1.dim() - 2],
      "x",
      mat1_sizes[mat1.dim() - 1],
      " and ",
      mat2_sizes[mat2.dim() - 2],
      "x",
      mat2_sizes[mat2.dim() - 1],
      ")");

  IntArrayRef self_sizes = self.sizes();
  TORCH_CHECK(
      self_sizes[self.dim() - 2] == mat1_sizes[mat1.dim() - 2],
      "sampled_addmm: self.shape[-2] must match mat1.shape[-2]");
  TORCH_CHECK(
      self_sizes[self.dim() - 1] == mat2_sizes[mat2.dim() - 1],
      "sampled_addmm: self.shape[-1] must match mat2.shape[-1]");
}

} // namespace sparse

DEFINE_DISPATCH(sampled_addmm_sparse_csr_stub);

} // namespace at::native