// aten/src/ATen/native/sparse/SparseBlasImpl.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/mkl/Sparse.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>

// Required for checking whether Triton kernels are available
#include <ATen/core/dispatch/Dispatcher.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

#if !AT_USE_MKL_SPARSE()
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#endif


namespace at::native::sparse::impl {

namespace {

#ifndef USE_ROCM
bool operands_support_triton_mm_kernel(const Tensor& compressed, const Tensor& strided) {
  // Triton works only with blocksizes which are powers of 2.
  const auto is_power_of_2 = [](int64_t v) -> bool {
    return !(v & (v - 1));
  };
  return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(compressed.layout(), "operands_support_triton_mm_kernel", [&] { return false; },
     [&] {
       const auto blocksize = at::sparse_csr::getBlockSize(compressed);
       // Dtype and blocksize checks for potential Triton usage.
       return ((strided.scalar_type() == ScalarType::Half
                || strided.scalar_type() == ScalarType::BFloat16
                || strided.scalar_type() == ScalarType::Float)
               && compressed.scalar_type() == strided.scalar_type()
               && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
               && (blocksize[0] >= 16) && (blocksize[1] >= 16)
               // lhs is retiled to (b0, b1) while rhs is to (b1, b0),
               // so the result is tiled to (b0, b0) and we need to make
               // sure that strided.size(-1) is divisible by b0.
               && strided.size(-1) % blocksize[0] == 0);
     });
}
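// For illustration (hypothetical operands, not exercised here): a float16 BSR
// matrix with blocksize (32, 64) multiplied by a float16 dense matrix whose
// last dimension is a multiple of 32 satisfies every condition above, while a
// blocksize of (12, 12) fails both the power-of-2 and the >= 16 requirements.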
#endif
}

Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& strided, Tensor& result) {
  const auto compressed_layout = compressed.layout();
  const auto compressed_layout_str = at::sparse_csr::layoutToString(compressed_layout);

  // Device restrictions
  TORCH_CHECK(compressed.device() == strided.device()
      && compressed.device() == result.device(),
      "spmm_out(): all input arguments are expected to be on the same device.");

  // Layout restrictions.
  TORCH_CHECK(compressed_layout == kSparseCsr || compressed_layout == kSparseBsr,
      "spmm(", compressed_layout_str, ", Strided): only Csr and Bsr formats are supported for the sparse argument.");
  TORCH_CHECK(result.layout() == kStrided,
      "spmm_out(): out argument is expected to be strided.");

  // Dtype restrictions.
  TORCH_CHECK(compressed.scalar_type() == strided.scalar_type(),
      "spmm(", compressed_layout_str, ", Strided): arguments expected to have the same dtype.");

  // Dim restrictions.
  TORCH_CHECK(compressed.dim() == 2,
      "spmm(", compressed_layout_str, ", Strided): sparse arguments which are not 2D are not supported.");
  TORCH_CHECK(strided.dim() >= 2,
      "spmm(", compressed_layout_str, ", Strided): expects strided inputs to be at least 2D.");

  const auto m = compressed.sizes()[0];
  const auto k = compressed.sizes()[1];
  const auto n = strided.size(-1);
  // Matrix product size compatibility.
  TORCH_CHECK(strided.size(-2) == k,
      "spmm(", compressed_layout_str, ", Strided): argument sizes are not compatible for matrix multiplication. ",
      "Got ", compressed_layout_str, ".size(-1) == ", k, " is not equal to ",
      "Strided.size(-2) == ", strided.size(-2), ".");

  // We assume that result is properly resized.
  auto result_expected_size = at::DimVector(strided.sizes().slice(0, strided.dim() - 2));
  result_expected_size.push_back(m);
  result_expected_size.push_back(n);
  TORCH_CHECK(result.sizes() == result_expected_size,
      "spmm_out(): out argument has wrong size. ",
      "Expected (", result_expected_size, ") but got (", result.sizes(), ").");

  auto values = compressed.values();

  using Blocksize = std::array<int64_t, 2>;
  // We refer to these as (b0, b1) in the comments below.
  Blocksize blocksize = {1, 1};
  if (compressed_layout == kSparseBsr) {
    blocksize = {values.size(-2), values.size(-1)};
  }
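  // For kSparseCsr the values are scalars, so everything below degenerates to
  // 1 x 1 "blocks" (b0 == b1 == 1) and the tiling merely appends two singleton
  // dimensions.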

// No stable support for ROCM in Triton yet.
#ifndef USE_ROCM

  if (operands_support_triton_mm_kernel(compressed, strided)) {
    const auto triton_schema = c10::Dispatcher::singleton()
      .findSchema({"triton::_triton_bsr_dense_mm_out", ""});
    if (triton_schema.has_value()) {
      const auto triton_kernel = triton_schema.value().typed<Tensor&(const Tensor&, const Tensor&, Tensor&)>();
      if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
        return triton_kernel.call(compressed, strided, result);
      }
    } /* else the schema is not defined and/or the key is not
         overwritten, so skip and execute the code below. */
  }
#endif
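  // The lookup above succeeds only when some component (typically the Python
  // Triton kernels, when they are available) has registered and implemented
  // triton::_triton_bsr_dense_mm_out for SparseCsrCUDA; otherwise we fall
  // through to the generic block algorithm below.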

  // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
  // NOTE: this function ALWAYS creates a view upon successful execution.
  const auto tile_tensor = [compressed_layout](
      const Tensor& t, Blocksize blocksize) -> Tensor {
    if (compressed_layout == kSparseCsr) {
      return t.unsqueeze(-1).unsqueeze_(-1);
    }
    else {
      const auto size_neg_2_blocked = t.size(-2) / blocksize[0];
      const auto size_neg_1_blocked = t.size(-1) / blocksize[1];
      auto tiled_sizes = at::DimVector(t.sizes().slice(0, t.dim() - 2));
      tiled_sizes.push_back(size_neg_2_blocked);
      tiled_sizes.push_back(blocksize[0]);
      tiled_sizes.push_back(size_neg_1_blocked);
      tiled_sizes.push_back(blocksize[1]);
      return t.reshape(tiled_sizes).transpose(-3, -2);
    }
  };
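  // Example: for kSparseBsr with blocksize {2, 3}, a (6, 12) tensor is
  // reshaped to (3, 2, 4, 3) and transposed to (3, 4, 2, 3): dims -4/-3 index
  // the tile grid while dims -2/-1 address elements within a tile.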

  // Note that sparse values are (..., b0, b1). This means that
  // the strided input has to be "tilable" to (..., b1, x) with
  // any x >= 1 such that all the shapes are (block) matrix product
  // compatible. The matrix product will then have shape (..., b0, x).
  // This in turn means the result has to be "tilable" to
  // (..., b0, x).
  //
  // These observations imply the following restrictions:
  // 1. strided.size(-2) has to be divisible by b1.
  // 2. result.size(-2) has to be divisible by b0.
  // 3. both strided.size(-1) and result.size(-1)
  //    have to be divisible by x.
  //
  // Restrictions 1 and 2 are trivially satisfied.
  // Regarding restriction 3:
  // it would make sense to take the largest possible x for better
  // performance since it is very likely that the last dimension
  // is contiguous. As such, this value is exactly
  // x = strided.size(-1), since strided.size(-1) == result.size(-1)

  // See the comments above. This is our x.
  const auto outer_blocksize = n;

  Blocksize strided_blocksize = {blocksize[1], outer_blocksize};
  const auto strided_tiled = tile_tensor(strided, strided_blocksize);

  // Left argument is (..., b0, b1) and right is (..., b1, x).
  // This naturally implies the result should be "tilable" as
  // (..., b0, x).
  Blocksize result_blocksize = {blocksize[0], outer_blocksize};
  auto result_tiled = tile_tensor(result, result_blocksize);

  if (compressed_layout == kSparseCsr) {
    values.unsqueeze_(-1).unsqueeze_(-1);
  }

  auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(compressed);

  // Select block rows of the strided input that intersect with the block columns of the sparse input.
  auto strided_tiled_selected_rows = strided_tiled.index_select(-4, plain_indices);

  // Promote to float if output is half or bfloat16 for better precision
  const auto mm_dtype = (result.scalar_type() == kHalf || result.scalar_type() == kBFloat16)
    ? kFloat : result.scalar_type();
  // Now that we know which block rows intersect with which block columns,
  // we can perform matrix products between pairs of blocks.
  // NOTE: .to is a no-op when result.scalar_type() == mm_dtype.
  const auto pairwise_block_mm = values.unsqueeze(-3).to(mm_dtype)
    .matmul(strided_tiled_selected_rows.to(mm_dtype));
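  // Shape walk-through for the BSR case (with x == n):
  //   values.unsqueeze(-3)          : (nnz, 1, b0, b1)
  //   strided_tiled_selected_rows   : (..., nnz, 1, b1, n)
  //   pairwise_block_mm (the matmul): (..., nnz, 1, b0, n)
  // i.e. one (b0, n) product per stored block, still keyed by block index.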

  // Having pairwise block matrix products stored in pairwise_block_mm,
  // it is sufficient to sum all the block products that share the same row
  // encoded in the sparse index. Since the reduction step is done via
  // advanced indexing methods, the compressed index ought to get converted
  // to the COO format.
  const auto compressed_indices_coo = at::_convert_indices_from_csr_to_coo(
      compressed_indices,
      plain_indices,
      compressed_indices.scalar_type() == kInt).select(0, 0);
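  // select(0, 0) keeps only the first coordinate (the row dimension) of the
  // 2 x nnz COO index; it serves as the destination index for the
  // index_add_ reduction below.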

  // Reduction step.
  // If result is neither half nor bfloat16, do everything in-place.
  if (result.scalar_type() == mm_dtype) {
    // Zero out and sum over the blocks that share the same row indices.
    result_tiled.zero_();
    result_tiled.index_add_(
        /*dim=*/-4,
        /*index=*/compressed_indices_coo,
        /*source=*/pairwise_block_mm);
  }
  // Otherwise accumulate into a buffer and then copy.
  else {
    // No need to zero out, sum over the blocks goes into a buffer
    // followed by a copy into result.
    auto promoted_result_tiled = at::zeros(
        result_tiled.sizes(),
        result_tiled.options().dtype(mm_dtype));
    promoted_result_tiled.index_add_(
        /*dim=*/-4,
        /*index=*/compressed_indices_coo,
        /*source=*/pairwise_block_mm);
    result_tiled.copy_(promoted_result_tiled);
  }

  return result;
}
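
// Sketch of a direct call (hypothetical tensors `a` and `b`; both must share
// device and dtype, and `a` must be a 2D CSR/BSR tensor):
//   auto out = at::zeros({a.size(0), b.size(-1)}, b.options());
//   at::native::sparse::impl::_compressed_row_strided_mm_out(a, b, out);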

Tensor& _compressed_row_strided_addmm_out(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {

// No stable support for ROCM in Triton yet.
#ifndef USE_ROCM
  if (operands_support_triton_mm_kernel(mat1, mat2)) {
    const auto triton_schema = c10::Dispatcher::singleton()
      .findSchema({"triton::_triton_bsr_dense_addmm_out", ""});
    if (triton_schema.has_value()) {
      const auto triton_kernel = triton_schema.value().typed<Tensor&(const Tensor&, const Tensor&, const Tensor&, const Scalar&, const Scalar&, Tensor&)>();
      if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
        try {
          return triton_kernel.call(self, mat1, mat2, beta, alpha, result);
        } catch (std::runtime_error& e) {
          const std::string msg = e.what();
          if (msg != std::string("Unable to cast NotImplemented to Tensor")) {
            throw std::runtime_error(msg);
          }
        } /* else triton_kernel returned NotImplemented, continue
             with the generic method below */
      }
    } /* else the schema is not defined and/or the key is not
           overwritten, so skip and execute the code below. */
  }
#endif

  auto alpha_val = alpha.toComplexDouble();
  auto beta_val = beta.toComplexDouble();
  // If result is not the same as self, it can always be used as the out argument to mm.
  if (!result.is_same(self)) {
    _compressed_row_strided_mm_out(mat1, mat2, result);
    if (alpha_val != 1.) {
      result.mul_(alpha);
    }
    // Process beta
    if (beta_val != 0.) {
      if (beta_val == 1.) {
        result.add_(self);
      } else {
        result.add_(self.mul(beta));
      }
    }
  }
  // Otherwise we need to allocate external memory for mm if beta != 0.
  else {
    // Process beta
    if (beta_val != 0.) {
      if (beta_val != 1.) {
        result.mul_(beta);
      }
      auto mm = at::empty_like(result);
      _compressed_row_strided_mm_out(mat1, mat2, mm);
      if (alpha_val != 1.) {
        mm.mul_(alpha);
      }
      result.add_(mm);
    }
    else {
      _compressed_row_strided_mm_out(mat1, mat2, result);
      if (alpha_val != 1.) {
        result.mul_(alpha);
      }
    }
  }

  return result;
}
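
// Summary of the branches above:
//   result is not self:             result = alpha * (mat1 @ mat2), then
//                                   result += self (beta == 1) or
//                                   result += beta * self (beta != 0, != 1);
//                                   nothing is added when beta == 0.
//   result aliases self, beta != 0: result *= beta (if beta != 1), then
//                                   result += alpha * (mat1 @ mat2) computed
//                                   in a temporary buffer.
//   result aliases self, beta == 0: result = alpha * (mat1 @ mat2).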

namespace cpu {
#if !AT_USE_MKL_SPARSE()
namespace {
template<typename scalar_t, typename idx_t>
void addmv_sparse_csr(
    const scalar_t* mat_values,
    const idx_t* crow_index,
    const idx_t* col_index,
    const int64_t mat_rows,
    const scalar_t* vec,
    const size_t vec_stride,
    const scalar_t alpha,
    const scalar_t beta,
    scalar_t* result,
    const size_t result_stride) {
  at::parallel_for(0, mat_rows, 0, [&](int64_t rstart, int64_t rend) {
    for(const auto row: c10::irange(rstart, rend)) {
      scalar_t acc(0);
      for(const auto idx: c10::irange(crow_index[row], crow_index[row + 1])) {
        acc += mat_values[idx] * vec[col_index[idx] * vec_stride];
      }
      result[row * result_stride] = acc * alpha + result[row * result_stride] * beta;
    }
  });
}
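
// Example (for illustration only): a 2 x 3 CSR matrix with
//   crow_index = {0, 2, 3}, col_index = {0, 2, 1}, mat_values = {a, b, c}
// produces
//   result[0] = (a * vec[0] + b * vec[2 * vec_stride]) * alpha + result[0] * beta
//   result[result_stride] = c * vec[vec_stride] * alpha + result[result_stride] * beta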

template<typename scalar_t, typename idx_t>
void addmv_sparse_bsr(
    const scalar_t* mat_values,
    const idx_t* crow_index,
    const idx_t* col_index,
    const int64_t mat_rows,
    const int64_t blocksize_rows,
    const int64_t blocksize_cols,
    const scalar_t* vec,
    const size_t vec_stride,
    const scalar_t alpha,
    const scalar_t beta,
    scalar_t* result,
    const size_t result_stride) {
  at::parallel_for(0, mat_rows, 0, [&](int64_t rstart, int64_t rend) {
    for(const auto row: c10::irange(rstart, rend)) {
      const auto block_row = row / blocksize_rows;
      const auto block_row_offset = row % blocksize_rows;
      scalar_t acc(0);
      for(const auto block_idx: c10::irange(crow_index[block_row], crow_index[block_row + 1])) {
        const auto block_offs = (block_idx * blocksize_rows + block_row_offset) * blocksize_cols;
        const auto vec_offs = col_index[block_idx] * blocksize_cols;
        for(const auto idx: c10::irange(blocksize_cols)) {
          acc += mat_values[block_offs + idx] * vec[(vec_offs + idx) * vec_stride];
        }
      }
      result[row * result_stride] = acc * alpha + result[row * result_stride] * beta;
    }
  });
}
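
// Note on the indexing above: mat_values is the contiguous BSR values buffer
// laid out as (num_blocks, blocksize_rows, blocksize_cols) in row-major order,
// so block_offs addresses row `block_row_offset` of block `block_idx`, and
// vec_offs is the first dense-vector element covered by that block's column.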

template<typename scalar_t, typename idx_t>
void addmv_out_sparse_csr(
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  auto cont_values = mat.values().contiguous();
  if (mat.layout() == kSparseBsr) {
    addmv_sparse_bsr(cont_values.data_ptr<scalar_t>(),
        mat.crow_indices().data_ptr<idx_t>(),
        mat.col_indices().data_ptr<idx_t>(),
        mat.size(0),
        mat.values().size(1),
        mat.values().size(2),
        vec.data_ptr<scalar_t>(),
        vec.stride(0),
        alpha.to<scalar_t>(),
        beta.to<scalar_t>(),
        result.data_ptr<scalar_t>(),
        result.stride(0));
  } else {
    addmv_sparse_csr(cont_values.data_ptr<scalar_t>(),
        mat.crow_indices().data_ptr<idx_t>(),
        mat.col_indices().data_ptr<idx_t>(),
        mat.size(0),
        vec.data_ptr<scalar_t>(),
        vec.stride(0),
        alpha.to<scalar_t>(),
        beta.to<scalar_t>(),
        result.data_ptr<scalar_t>(),
        result.stride(0));
  }
}
} // anonymous namespace
#endif // !AT_USE_MKL_SPARSE()

/*
  Computes a sparse matrix-dense vector product defined as
  y <- alpha*op(A)*x + beta*y

  Args:
  * `mat` - Tensor storing sparse m x n matrix A.
  * `vec` - Tensor storing dense vector x of size n.
  * `result` - [in] Tensor storing dense vector y of size m.
               [out] result of the operation.
*/
void addmv_out_sparse_csr(
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(mat.layout() == kSparseBsr || mat.layout() == kSparseCsr, "Unexpected layout ", mat.layout());
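  // When beta == 0 the output is explicitly zeroed first: the reference
  // kernels above always compute `acc * alpha + result * beta`, and without
  // the zeroing a NaN/Inf (or uninitialized) entry in `result` would leak
  // into the output even though beta is zero.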
  if (beta.toComplexDouble() == 0.) {
    result.zero_();
  }
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "addmv_out_sparse_csr_impl_reference", [&] {
        if (mat.crow_indices().scalar_type() == kLong) {
          addmv_out_sparse_csr<scalar_t, int64_t>(mat, vec, beta, alpha, result);
        } else {
          addmv_out_sparse_csr<scalar_t, int32_t>(mat, vec, beta, alpha, result);
        }
      });
#else
  sparse::impl::mkl::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
#endif
}

/*
  Computes a sum of two sparse matrices defined as
  result <- mat1 + alpha*mat2

  Args:
  * `mat1` - CSR Tensor storing sparse m x n matrix.
  * `mat2` - CSR Tensor storing sparse m x n matrix.
  * `result` - [in] CSR Tensor storing sparse m x n matrix.
               [out] result of the operation.
*/
void add_out_sparse_csr(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_MKL_ENABLED()
  TORCH_CHECK(
      false,
      "Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
      "Please use PyTorch built with MKL support.");
#else
  sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
#endif
}

void triangular_solve_out_sparse_csr(
    const Tensor& A,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular) {
#if !AT_MKL_ENABLED()
  TORCH_CHECK(
      false,
      "Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ",
      "Please use PyTorch built with MKL support.");
#else
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseCsr || A.layout() == kSparseBsr);
  sparse::impl::mkl::triangular_solve_out_sparse_csr(A, B, X, upper, transpose, unitriangular);
#endif
}

} // namespace cpu
} // namespace at::native::sparse::impl