#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <c10/util/irange.h>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#endif

namespace at::native {

using namespace at::sparse;

/*
    This is an implementation of the SMMP algorithm:
     "Sparse Matrix Multiplication Package (SMMP)"

      Randolph E. Bank and Craig C. Douglas
      https://doi.org/10.1007/BF02070824
*/
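/*
    Both operands are converted to CSR (compressed sparse row) form below.
    Illustrative example (assumed values, not from the original source): the
    2x3 matrix
        [[1, 0, 2],
         [0, 0, 3]]
    is stored in CSR as
        crow_indices (row pointer) = [0, 2, 3]
        col_indices                = [0, 2, 2]
        values                     = [1, 2, 3]
    where row i occupies positions crow_indices[i] .. crow_indices[i + 1] - 1
    of col_indices/values.
*/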
namespace {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
  /*
    Expands a compressed row pointer into a row indices array.
    Inputs:
      `n_row` is the number of rows (`Ap` has `n_row + 1` entries)
      `Ap` is the CSR row pointer

    Output:
      `Bi` receives the row index of every nonzero (COO row indices)
  */
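  // Illustrative example (assumed values, not from the original source):
  // with n_row = 3 and Ap = [0, 2, 3, 5], this fills Bi = [0, 0, 1, 2, 2].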
  for (const auto i : c10::irange(n_row)) {
    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      Bi[jj] = i;
    }
  }
}

template<typename index_t_ptr = int64_t*>
int64_t _csr_matmult_maxnnz(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const index_t_ptr Bp,
    const index_t_ptr Bj) {
  /*
    Compute the number of nonzeros needed for matrix `C` in the `C = A@B`
    operation, i.e. the buffer size to preallocate for its indices and values.

    The matrices must be in proper CSR structure, and their dimensions
    must be compatible.
  */
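  // `mask[k] == i` records that column k has already been counted for output
  // row i, so repeated contributions to the same C[i, k] are counted once.
  // The count is structural: entries that cancel numerically still count.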
  std::vector<int64_t> mask(n_col, -1);
  int64_t nnz = 0;
  for (const auto i : c10::irange(n_row)) {
    int64_t row_nnz = 0;

    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      int64_t j = Aj[jj];
      for (int64_t kk = Bp[j]; kk < Bp[j + 1]; kk++) {
        int64_t k = Bj[kk];
        if (mask[k] != i) {
          mask[k] = i;
          row_nnz++;
        }
      }
    }
    int64_t next_nnz = nnz + row_nnz;
    nnz = next_nnz;
  }
  return nnz;
}

template<typename index_t_ptr, typename scalar_t_ptr>
void _csr_matmult(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const scalar_t_ptr Ax,
    const index_t_ptr Bp,
    const index_t_ptr Bj,
    const scalar_t_ptr Bx,
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename scalar_t_ptr::value_type Cx[]) {
  /*
    Compute CSR entries for matrix C = A@B.

    The matrices `A` and `B` must be in proper CSR structure, and their
    dimensions must be compatible.

    Inputs:
      `n_row`         - number of rows in A
      `n_col`         - number of columns in B
      `Ap[n_row+1]`   - row pointer
      `Aj[nnz(A)]`    - column indices
      `Ax[nnz(A)]`    - nonzeros
      `Bp[?]`         - row pointer (one entry per row of B, plus one)
      `Bj[nnz(B)]`    - column indices
      `Bx[nnz(B)]`    - nonzeros
    Outputs:
      `Cp[n_row+1]` - row pointer
      `Cj[nnz(C)]`  - column indices
      `Cx[nnz(C)]`  - nonzeros

    Note:
      Output arrays Cp, Cj, and Cx must be preallocated.
  */
  using index_t = typename index_t_ptr::value_type;
  using scalar_t = typename scalar_t_ptr::value_type;

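  // `next` threads a singly linked list through the columns touched in the
  // current output row: next[k] == -1 means column k is not in the list, and
  // head == -2 marks the end of the list. `sums[k]` accumulates C[i, k].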
  std::vector<index_t> next(n_col, -1);
  std::vector<scalar_t> sums(n_col, 0);

  int64_t nnz = 0;

  Cp[0] = 0;

  for (const auto i : c10::irange(n_row)) {
    index_t head = -2;
    index_t length = 0;

    index_t jj_start = Ap[i];
    index_t jj_end = Ap[i + 1];
    for (const auto jj : c10::irange(jj_start, jj_end)) {
      index_t j = Aj[jj];
      scalar_t v = Ax[jj];

      index_t kk_start = Bp[j];
      index_t kk_end = Bp[j + 1];
      for (const auto kk : c10::irange(kk_start, kk_end)) {
        index_t k = Bj[kk];

        sums[k] += v * Bx[kk];

        if (next[k] == -1) {
          next[k] = head;
          head = k;
          length++;
        }
      }
    }

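    // Walk the linked list to emit the `length` entries of row i, resetting
    // `next` and `sums` along the way so they can be reused for the next row.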
    for (C10_UNUSED const auto jj : c10::irange(length)) {

      // NOTE: the linked list that encodes col indices
      // is not guaranteed to be sorted.
      Cj[nnz] = head;
      Cx[nnz] = sums[head];
      nnz++;

      index_t temp = head;
      head = next[head];

      next[temp] = -1; // clear arrays
      sums[temp] = 0;
    }

    // Make sure that col indices are sorted.
    // TODO: a better approach is to implement a CSR @ CSC kernel.
    // NOTE: Cx arrays are expected to be contiguous!
    auto col_indices_accessor = StridedRandomAccessor<int64_t>(Cj + nnz - length, 1);
    auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
    auto kv_accessor = CompositeRandomAccessorCPU<
      decltype(col_indices_accessor), decltype(val_accessor)
    >(col_indices_accessor, val_accessor);
    std::sort(kv_accessor, kv_accessor + length, [](const auto& lhs, const auto& rhs) -> bool {
        return get<0>(lhs) < get<0>(rhs);
    });

    Cp[i + 1] = nnz;
  }
}


template <typename scalar_t>
void sparse_matmul_kernel(
    Tensor& output,
    const Tensor& mat1,
    const Tensor& mat2) {
  /*
    Computes the sparse-sparse matrix multiplication of `mat1` and `mat2`,
    both sparse tensors in COO format, and writes the result into `output`.
  */
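  // Overall flow: convert both operands to CSR, compute nnz(C) to size the
  // output buffers, run the CSR x CSR multiply, then expand C's row pointer
  // back into COO row indices.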

  auto M = mat1.size(0);
  auto N = mat2.size(1);

  const auto mat1_csr = mat1.to_sparse_csr();
  const auto mat2_csr = mat2.to_sparse_csr();

  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.crow_indices().data_ptr<int64_t>(),
      mat1_csr.crow_indices().stride(-1));
  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.col_indices().data_ptr<int64_t>(),
      mat1_csr.col_indices().stride(-1));
  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
      mat1_csr.values().data_ptr<scalar_t>(),
      mat1_csr.values().stride(-1));
  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.crow_indices().data_ptr<int64_t>(),
      mat2_csr.crow_indices().stride(-1));
  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.col_indices().data_ptr<int64_t>(),
      mat2_csr.col_indices().stride(-1));
  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
      mat2_csr.values().data_ptr<scalar_t>(),
      mat2_csr.values().stride(-1));

  const auto nnz = _csr_matmult_maxnnz(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr);

  auto output_indices = output._indices();
  auto output_values = output._values();

  Tensor output_indptr = at::empty({M + 1}, kLong);
  at::native::resize_output(output_indices, {2, nnz});
  at::native::resize_output(output_values, nnz);

  Tensor output_row_indices = output_indices.select(0, 0);
  Tensor output_col_indices = output_indices.select(0, 1);

  // TODO: replace with a CSR @ CSC kernel for better performance.
  _csr_matmult(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat1_values_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr,
      mat2_values_ptr,
      output_indptr.data_ptr<int64_t>(),
      output_col_indices.data_ptr<int64_t>(),
      output_values.data_ptr<scalar_t>());

  csr_to_coo(M, output_indptr.data_ptr<int64_t>(), output_row_indices.data_ptr<int64_t>());
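  // Row indices come out in increasing order and column indices are sorted
  // and deduplicated within each row by _csr_matmult, so the result is
  // already coalesced.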
  output._coalesced_(true);
}

} // end anonymous namespace

Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
  TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
  TORCH_CHECK(mat1_.dim() == 2);
  TORCH_CHECK(mat2_.dim() == 2);
  TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat1_.dense_dim(), "D values");
  TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat2_.dense_dim(), "D values");

  TORCH_CHECK(
      mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
      mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");

  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
           "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());

  auto output = at::native::empty_like(mat1_);
  output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
    sparse_matmul_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
  return output;
}
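
// Usage sketch (illustrative, not part of the original file): on CPU this
// kernel is reached through the `_sparse_sparse_matmul` op for 2D sparse COO
// inputs with matching floating/complex dtypes, e.g.
//   Tensor a = at::rand({4, 6}).to_sparse();
//   Tensor b = at::rand({6, 3}).to_sparse();
//   Tensor c = at::_sparse_sparse_matmul(a, b);  // sparse COO result of a @ b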


} // namespace at::native