#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/mkl/Sparse.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>

// Required for checking whether Triton kernels are available
#include <ATen/core/dispatch/Dispatcher.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

#if !AT_USE_MKL_SPARSE()
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#endif


namespace at::native::sparse::impl {

namespace {

#ifndef USE_ROCM
bool operands_support_triton_mm_kernel(const Tensor& compressed, const Tensor& strided) {
  // Triton works only with blocksizes which are powers of 2.
  const auto is_power_of_2 = [](int64_t v) -> bool {
    return !(v & (v - 1));
  };
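  // Note: is_power_of_2(0) evaluates to true, but a zero blocksize is already
  // ruled out by the >= 16 checks below.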
  return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(compressed.layout(), "operands_support_triton_mm_kernel", [&] { return false; },
      [&] {
        const auto blocksize = at::sparse_csr::getBlockSize(compressed);
        // Dtype and blocksize checks for potential Triton usage.
        return ((strided.scalar_type() == ScalarType::Half
                 || strided.scalar_type() == ScalarType::BFloat16
                 || strided.scalar_type() == ScalarType::Float)
                && compressed.scalar_type() == strided.scalar_type()
                && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
                && (blocksize[0] >= 16) && (blocksize[1] >= 16)
                // lhs is retiled to (b0, b1) while rhs is to (b1, b0),
                // so the result is tiled to (b0, b0) and we need to make
                // sure that strided.size(-1) is divisible by b0.
                && strided.size(-1) % blocksize[0] == 0);
      });
}
#endif
} // anonymous namespace

Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& strided, Tensor& result) {
  const auto compressed_layout = compressed.layout();
  const auto compressed_layout_str = at::sparse_csr::layoutToString(compressed_layout);

  // Device restrictions.
  TORCH_CHECK(compressed.device() == strided.device()
      && compressed.device() == result.device(),
      "spmm_out(): all input arguments are expected to be on the same device.");

  // Layout restrictions.
  TORCH_CHECK(compressed_layout == kSparseCsr || compressed_layout == kSparseBsr,
      "spmm(", compressed_layout_str, ", Strided): only Csr and Bsr formats are supported for the sparse argument.");
  TORCH_CHECK(result.layout() == kStrided,
      "spmm_out(): out argument is expected to be strided.");

  // Dtype restrictions.
  TORCH_CHECK(compressed.scalar_type() == strided.scalar_type(),
      "spmm(", compressed_layout_str, ", Strided): arguments expected to have the same dtype.");

  // Dim restrictions.
  TORCH_CHECK(compressed.dim() == 2,
      "spmm(", compressed_layout_str, ", Strided): sparse arguments which are not 2D are not supported.");
  TORCH_CHECK(strided.dim() >= 2,
      "spmm(", compressed_layout_str, ", Strided): expects strided inputs to be at least 2D.");

  const auto m = compressed.sizes()[0];
  const auto k = compressed.sizes()[1];
  const auto n = strided.size(-1);
  // Matrix product size compatibility.
  TORCH_CHECK(strided.size(-2) == k,
      "spmm(", compressed_layout_str, ", Strided): argument sizes are not compatible for matrix multiplication. ",
      "Got ", compressed_layout_str, ".size(-1) == ", k, ", which is not equal to ",
      "Strided.size(-2) == ", strided.size(-2), ".");

  // We assume that result is properly resized.
  auto result_expected_size = at::DimVector(strided.sizes().slice(0, strided.dim() - 2));
  result_expected_size.push_back(m);
  result_expected_size.push_back(n);
  TORCH_CHECK(result.sizes() == result_expected_size,
      "spmm_out(): out argument has wrong size. ",
      "Expected (", result_expected_size, ") but got (", result.sizes(), ").");

  auto values = compressed.values();

  using Blocksize = std::array<int64_t, 2>;
  // We refer to these as (b0, b1) in the comments below.
  Blocksize blocksize = {1, 1};
  if (compressed_layout == kSparseBsr) {
    blocksize = {values.size(-2), values.size(-1)};
  }

  // No stable support for ROCM in Triton yet.
#ifndef USE_ROCM

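  // If a Triton kernel has been registered for this operation (typically done
  // at runtime from the Python side), prefer it over the generic fallback
  // below; hence the lookup through the dispatcher rather than a direct call.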
  if (operands_support_triton_mm_kernel(compressed, strided)) {
    const auto triton_schema = c10::Dispatcher::singleton()
        .findSchema({"triton::_triton_bsr_dense_mm_out", ""});
    if (triton_schema.has_value()) {
      const auto triton_kernel = triton_schema.value().typed<Tensor&(const Tensor&, const Tensor&, Tensor&)>();
      if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
        return triton_kernel.call(compressed, strided, result);
      }
    } /* else the schema is not defined and/or the key is not
         overridden, so skip and execute the code below. */
  }
#endif

  // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
  // NOTE: this function ALWAYS creates a view upon successful execution.
  const auto tile_tensor = [compressed_layout](
      const Tensor& t, Blocksize blocksize) -> Tensor {
    if (compressed_layout == kSparseCsr) {
      return t.unsqueeze(-1).unsqueeze_(-1);
    }
    else {
      const auto size_neg_2_blocked = t.size(-2) / blocksize[0];
      const auto size_neg_1_blocked = t.size(-1) / blocksize[1];
      auto tiled_sizes = at::DimVector(t.sizes().slice(0, t.dim() - 2));
      tiled_sizes.push_back(size_neg_2_blocked);
      tiled_sizes.push_back(blocksize[0]);
      tiled_sizes.push_back(size_neg_1_blocked);
      tiled_sizes.push_back(blocksize[1]);
      return t.reshape(tiled_sizes).transpose(-3, -2);
    }
  };

  // Note that sparse values are (..., b0, b1). This means that
  // the strided input has to be "tilable" to (..., b1, x) with
  // any x >= 1 such that all the shapes are (block) matrix product
  // compatible. The matrix product will then have shape (..., b0, x).
  // This in turn means the result has to be "tilable" to
  // (..., b0, x).
  //
  // These observations imply the following restrictions:
  // 1. strided.size(-2) has to be divisible by b1.
  // 2. result.size(-2) has to be divisible by b0.
  // 3. both strided.size(-1) and result.size(-1)
  //    have to be divisible by x.
  //
  // Restrictions 1 and 2 are trivially satisfied.
  // Regarding restriction 3:
  // it would make sense to take the largest possible x for better
  // performance since it is very likely that the last dimension
  // is contiguous. As such, this value is exactly
  // x = strided.size(-1), since strided.size(-1) == result.size(-1).
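  //
  // A shape-only illustration (hypothetical sizes, assuming a Bsr input with
  // blocksize (b0, b1) = (32, 64), m = 256, k = 512, n = 128):
  //   values:              (nnz, 32, 64)
  //   tiled strided input: (512 / 64, 128 / 128, 64, 128) = (8, 1, 64, 128)
  //   tiled result:        (256 / 32, 128 / 128, 32, 128) = (8, 1, 32, 128)
  // so each pairwise block product multiplies a (32, 64) sparse block with a
  // (64, 128) tile of the strided input into a (32, 128) tile of the result.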

  // See the comments above. This is our x.
  const auto outer_blocksize = n;

  Blocksize strided_blocksize = {blocksize[1], outer_blocksize};
  const auto strided_tiled = tile_tensor(strided, strided_blocksize);

  // Left argument is (..., b0, b1) and right is (..., b1, x).
  // This naturally implies the result should be "tilable" as
  // (..., b0, x).
  Blocksize result_blocksize = {blocksize[0], outer_blocksize};
  auto result_tiled = tile_tensor(result, result_blocksize);

  if (compressed_layout == kSparseCsr) {
    values.unsqueeze_(-1).unsqueeze_(-1);
  }

  auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(compressed);

  // Select block rows of the strided input that intersect with the block columns of the sparse input.
  auto strided_tiled_selected_rows = strided_tiled.index_select(-4, plain_indices);

  // Promote to float if output is half or bfloat16 for better precision.
  const auto mm_dtype = (result.scalar_type() == kHalf || result.scalar_type() == kBFloat16)
      ? kFloat : result.scalar_type();
  // Now that we know which block rows intersect with which block columns,
  // we can perform matrix products between pairs of blocks.
  // NOTE: .to is a no-op when result.scalar_type() == mm_dtype.
  const auto pairwise_block_mm = values.unsqueeze(-3).to(mm_dtype)
      .matmul(strided_tiled_selected_rows.to(mm_dtype));

  // Having pairwise block matrix products stored in pairwise_block_mm,
  // it is sufficient to sum all the block products that share the same row
  // encoded in the sparse index. Since the reduction step is done via
  // advanced indexing methods, the compressed index ought to get converted
  // to the COO format.
  const auto compressed_indices_coo = at::_convert_indices_from_csr_to_coo(
      compressed_indices,
      plain_indices,
      compressed_indices.scalar_type() == kInt).select(0, 0);
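  // For example (hypothetical indices): compressed_indices = [0, 2, 3]
  // converts to row indices [0, 0, 1], so the first two pairwise block
  // products are accumulated into block row 0 of the result and the third
  // into block row 1.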

  // Reduction step.
  // If result is neither half nor bfloat16, do everything in-place.
  if (result.scalar_type() == mm_dtype) {
    // Zero out and sum over the blocks that share the same row indices.
    result_tiled.zero_();
    result_tiled.index_add_(
        /*dim=*/-4,
        /*index=*/compressed_indices_coo,
        /*source=*/pairwise_block_mm);
  }
  // Otherwise accumulate into a buffer and then copy.
  else {
    // No need to zero out: the sum over the blocks goes into a buffer
    // which is then copied into result.
    auto promoted_result_tiled = at::zeros(
        result_tiled.sizes(),
        result_tiled.options().dtype(mm_dtype));
    promoted_result_tiled.index_add_(
        /*dim=*/-4,
        /*index=*/compressed_indices_coo,
        /*source=*/pairwise_block_mm);
    result_tiled.copy_(promoted_result_tiled);
  }

  return result;
}

Tensor& _compressed_row_strided_addmm_out(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {

  // No stable support for ROCM in Triton yet.
#ifndef USE_ROCM
  if (operands_support_triton_mm_kernel(mat1, mat2)) {
    const auto triton_schema = c10::Dispatcher::singleton()
        .findSchema({"triton::_triton_bsr_dense_addmm_out", ""});
    if (triton_schema.has_value()) {
      const auto triton_kernel = triton_schema.value().typed<Tensor&(const Tensor&, const Tensor&, const Tensor&, const Scalar&, const Scalar&, Tensor&)>();
      if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
        try {
          return triton_kernel.call(self, mat1, mat2, beta, alpha, result);
        } catch (std::runtime_error& e) {
          const std::string msg = e.what();
          if (msg != std::string("Unable to cast NotImplemented to Tensor")) {
            throw std::runtime_error(msg);
          }
        } /* else triton_kernel returned NotImplemented, continue
             with the generic method below */
      }
    } /* else the schema is not defined and/or the key is not
         overridden, so skip and execute the code below. */
  }
#endif

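  // Generic fallback: compute result = beta * self + alpha * (mat1 @ mat2)
  // on top of _compressed_row_strided_mm_out, taking care of the case where
  // result aliases self.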
  auto alpha_val = alpha.toComplexDouble();
  auto beta_val = beta.toComplexDouble();
  // If result is not the same as self, it can always be used as the out argument to mm.
  if (!result.is_same(self)) {
    _compressed_row_strided_mm_out(mat1, mat2, result);
    if (alpha_val != 1.) {
      result.mul_(alpha);
    }
    // Process beta.
    if (beta_val != 0.) {
      if (beta_val == 1.) {
        result.add_(self);
      } else {
        result.add_(self.mul(beta));
      }
    }
  }
  // Otherwise we need to allocate external memory for mm if beta != 0.
  else {
    // Process beta.
    if (beta_val != 0.) {
      if (beta_val != 1.) {
        result.mul_(beta);
      }
      auto mm = at::empty_like(result);
      _compressed_row_strided_mm_out(mat1, mat2, mm);
      if (alpha_val != 1.) {
        mm.mul_(alpha);
      }
      result.add_(mm);
    }
    else {
      _compressed_row_strided_mm_out(mat1, mat2, result);
      if (alpha_val != 1.) {
        result.mul_(alpha);
      }
    }
  }

  return result;
}

namespace cpu {
#if !AT_USE_MKL_SPARSE()
namespace {
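// Reference CSR matrix-vector kernel used when MKL is not available: for each
// row, accumulate mat_values[idx] * vec[col_index[idx]] over the row's
// nonzeros and write alpha * acc + beta * result[row] back, with rows
// processed in parallel via at::parallel_for.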
template<typename scalar_t, typename idx_t>
void addmv_sparse_csr(
    const scalar_t* mat_values,
    const idx_t* crow_index,
    const idx_t* col_index,
    const int64_t mat_rows,
    const scalar_t* vec,
    const size_t vec_stride,
    const scalar_t alpha,
    const scalar_t beta,
    scalar_t* result,
    const size_t result_stride) {
  at::parallel_for(0, mat_rows, 0, [&](int64_t rstart, int64_t rend) {
    for (const auto row : c10::irange(rstart, rend)) {
      scalar_t acc(0);
      for (const auto idx : c10::irange(crow_index[row], crow_index[row + 1])) {
        acc += mat_values[idx] * vec[col_index[idx] * vec_stride];
      }
      result[row * result_stride] = acc * alpha + result[row * result_stride] * beta;
    }
  });
}

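// Reference BSR matrix-vector kernel: each logical row maps to a block row
// plus an offset inside the block; the inner loops walk the block row's
// nonzero blocks and the columns within each block.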
template<typename scalar_t, typename idx_t>
void addmv_sparse_bsr(
    const scalar_t* mat_values,
    const idx_t* crow_index,
    const idx_t* col_index,
    const int64_t mat_rows,
    const int64_t blocksize_rows,
    const int64_t blocksize_cols,
    const scalar_t* vec,
    const size_t vec_stride,
    const scalar_t alpha,
    const scalar_t beta,
    scalar_t* result,
    const size_t result_stride) {
  at::parallel_for(0, mat_rows, 0, [&](int64_t rstart, int64_t rend) {
    for (const auto row : c10::irange(rstart, rend)) {
      const auto block_row = row / blocksize_rows;
      const auto block_row_offset = row % blocksize_rows;
      scalar_t acc(0);
      for (const auto block_idx : c10::irange(crow_index[block_row], crow_index[block_row + 1])) {
        const auto block_offs = (block_idx * blocksize_rows + block_row_offset) * blocksize_cols;
        const auto vec_offs = col_index[block_idx] * blocksize_cols;
        for (const auto idx : c10::irange(blocksize_cols)) {
          acc += mat_values[block_offs + idx] * vec[(vec_offs + idx) * vec_stride];
        }
      }
      result[row * result_stride] = acc * alpha + result[row * result_stride] * beta;
    }
  });
}

template<typename scalar_t, typename idx_t>
void addmv_out_sparse_csr(
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  auto cont_values = mat.values().contiguous();
  if (mat.layout() == kSparseBsr) {
    addmv_sparse_bsr(cont_values.data_ptr<scalar_t>(),
                     mat.crow_indices().data_ptr<idx_t>(),
                     mat.col_indices().data_ptr<idx_t>(),
                     mat.size(0),
                     mat.values().size(1),
                     mat.values().size(2),
                     vec.data_ptr<scalar_t>(),
                     vec.stride(0),
                     alpha.to<scalar_t>(),
                     beta.to<scalar_t>(),
                     result.data_ptr<scalar_t>(),
                     result.stride(0));
  } else {
    addmv_sparse_csr(cont_values.data_ptr<scalar_t>(),
                     mat.crow_indices().data_ptr<idx_t>(),
                     mat.col_indices().data_ptr<idx_t>(),
                     mat.size(0),
                     vec.data_ptr<scalar_t>(),
                     vec.stride(0),
                     alpha.to<scalar_t>(),
                     beta.to<scalar_t>(),
                     result.data_ptr<scalar_t>(),
                     result.stride(0));
  }
}
} // anonymous namespace
#endif // !AT_USE_MKL_SPARSE()

/*
  Computes a sparse matrix-dense vector product defined as
  y <- alpha*op(A)*x + beta*y

  Args:
  * `mat` - Tensor storing sparse m x n matrix A.
  * `vec` - Tensor storing dense vector x of size n.
  * `result` - [in] Tensor storing dense vector y of size m.
               [out] result of the operation.
*/
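// Illustrative Python-level correspondence (assumed call pattern, dispatch
// details omitted): torch.addmv(y, A, x, beta=beta, alpha=alpha) with A a
// CPU sparse CSR/BSR tensor.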
void addmv_out_sparse_csr(
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(mat.layout() == kSparseBsr || mat.layout() == kSparseCsr, "Unexpected layout ", mat.layout());
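  // When beta == 0, the initial content of result must not leak into the
  // output (it may be uninitialized or contain NaNs), so zero it out rather
  // than rely on the result[row] * beta term inside the kernels.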
  if (beta.toComplexDouble() == 0.) {
    result.zero_();
  }
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "addmv_out_sparse_csr_impl_reference", [&] {
        if (mat.crow_indices().scalar_type() == kLong) {
          addmv_out_sparse_csr<scalar_t, int64_t>(mat, vec, beta, alpha, result);
        } else {
          addmv_out_sparse_csr<scalar_t, int32_t>(mat, vec, beta, alpha, result);
        }
      });
#else
  sparse::impl::mkl::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
#endif
}

/*
  Computes a sum of two sparse matrices defined as
  result <- mat1 + alpha*mat2

  Args:
  * `mat1` - CSR Tensor storing sparse m x n matrix.
  * `mat2` - CSR Tensor storing sparse m x n matrix.
  * `result` - [in] CSR Tensor storing sparse m x n matrix.
               [out] result of the operation.
*/
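// Illustrative Python-level correspondence (assumed call pattern, dispatch
// details omitted): torch.add(mat1, mat2, alpha=alpha) with CPU sparse CSR
// operands; MKL is required, otherwise the check below raises.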
void add_out_sparse_csr(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_MKL_ENABLED()
  TORCH_CHECK(
      false,
      "Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
      "Please use PyTorch built with MKL support.");
#else
  sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
#endif
}

void triangular_solve_out_sparse_csr(
    const Tensor& A,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular) {
#if !AT_MKL_ENABLED()
  TORCH_CHECK(
      false,
      "Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ",
      "Please use PyTorch built with MKL support.");
#else
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseCsr || A.layout() == kSparseBsr);
  sparse::impl::mkl::triangular_solve_out_sparse_csr(A, B, X, upper, transpose, unitriangular);
#endif
}

} // namespace cpu
} // namespace at::native::sparse::impl