#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <c10/util/irange.h>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#endif

namespace at::native {

using namespace at::sparse;

/*
  This is an implementation of the SMMP algorithm:
    "Sparse Matrix Multiplication Package (SMMP)"

    Randolph E. Bank and Craig C. Douglas
    https://doi.org/10.1007/BF02070824
*/
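/*
  For reference, the helpers below operate on the standard CSR layout: an
  m x n matrix with nnz stored entries is described by three arrays
    Ap[m + 1]  - row pointer: entries of row i live in [Ap[i], Ap[i + 1])
    Aj[nnz]    - column index of each stored entry
    Ax[nnz]    - value of each stored entry
  As a small illustrative example (not taken from any call site), the matrix
    [[1 0 2]
     [0 0 3]]
  is encoded as Ap = [0, 2, 3], Aj = [0, 2, 2], Ax = [1, 2, 3].
*/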
namespace {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
  /*
    Expands a compressed row pointer into a row indices array
    Inputs:
      `n_row` is the number of rows (so `Ap` has `n_row + 1` entries)
      `Ap` is the row pointer

    Output:
      `Bi` is the row indices
  */
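  // Illustrative example (values not taken from any particular call site):
  // with n_row = 3 and Ap = [0, 2, 3, 3], row 0 owns entries 0..1, row 1 owns
  // entry 2, and row 2 is empty, so Bi becomes [0, 0, 1].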
  for (const auto i : c10::irange(n_row)) {
    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      Bi[jj] = i;
    }
  }
}

template<typename index_t_ptr = int64_t*>
int64_t _csr_matmult_maxnnz(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const index_t_ptr Bp,
    const index_t_ptr Bj) {
  /*
    Compute the number of non-zero entries (i.e. the needed buffer size)
    for matrix `C` in the `C = A@B` operation.

    The matrices should be in proper CSR structure, and their dimensions
    should be compatible.
  */
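  // `mask[k]` remembers the last row `i` that produced an entry in output
  // column `k`, so each structural non-zero of row `i` of `C` is counted
  // exactly once even when several terms contribute to it. This is a purely
  // structural count: entries that cancel numerically are still counted.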
  std::vector<int64_t> mask(n_col, -1);
  int64_t nnz = 0;
  for (const auto i : c10::irange(n_row)) {
    int64_t row_nnz = 0;

    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      int64_t j = Aj[jj];
      for (int64_t kk = Bp[j]; kk < Bp[j + 1]; kk++) {
        int64_t k = Bj[kk];
        if (mask[k] != i) {
          mask[k] = i;
          row_nnz++;
        }
      }
    }
    nnz += row_nnz;
  }
  return nnz;
}

template<typename index_t_ptr, typename scalar_t_ptr>
void _csr_matmult(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const scalar_t_ptr Ax,
    const index_t_ptr Bp,
    const index_t_ptr Bj,
    const scalar_t_ptr Bx,
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename scalar_t_ptr::value_type Cx[]) {
  /*
    Compute CSR entries for matrix C = A@B.

    The matrices `A` and `B` should be in proper CSR structure, and their
    dimensions should be compatible.

    Inputs:
      `n_row`       - number of rows in A
      `n_col`       - number of columns in B
      `Ap[n_row+1]` - row pointer
      `Aj[nnz(A)]`  - column indices
      `Ax[nnz(A)]`  - nonzeros
      `Bp[?]`       - row pointer
      `Bj[nnz(B)]`  - column indices
      `Bx[nnz(B)]`  - nonzeros
    Outputs:
      `Cp[n_row+1]` - row pointer
      `Cj[nnz(C)]`  - column indices
      `Cx[nnz(C)]`  - nonzeros

    Note:
      Output arrays Cp, Cj, and Cx must be preallocated
  */
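  // Each output row is accumulated with a dense `sums` buffer plus a singly
  // linked list threaded through `next`: `head` points at the most recently
  // touched column and `next[k]` points at the column touched before `k`
  // (-2 terminates the list, -1 means "not in the list"). This gives O(1)
  // insertion per term and lets the row be drained by walking the list
  // instead of rescanning all `n_col` columns.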
  using index_t = typename index_t_ptr::value_type;
  using scalar_t = typename scalar_t_ptr::value_type;

  std::vector<index_t> next(n_col, -1);
  std::vector<scalar_t> sums(n_col, 0);

  int64_t nnz = 0;

  Cp[0] = 0;

  for (const auto i : c10::irange(n_row)) {
    index_t head = -2;
    index_t length = 0;

    index_t jj_start = Ap[i];
    index_t jj_end = Ap[i + 1];
    for (const auto jj : c10::irange(jj_start, jj_end)) {
      index_t j = Aj[jj];
      scalar_t v = Ax[jj];

      index_t kk_start = Bp[j];
      index_t kk_end = Bp[j + 1];
      for (const auto kk : c10::irange(kk_start, kk_end)) {
        index_t k = Bj[kk];

        sums[k] += v * Bx[kk];

        if (next[k] == -1) {
          next[k] = head;
          head = k;
          length++;
        }
      }
    }

    for (C10_UNUSED const auto jj : c10::irange(length)) {

      // NOTE: the linked list that encodes col indices
      // is not guaranteed to be sorted.
      Cj[nnz] = head;
      Cx[nnz] = sums[head];
      nnz++;

      index_t temp = head;
      head = next[head];

      next[temp] = -1; // clear arrays
      sums[temp] = 0;
    }

    // Make sure that col indices are sorted.
    // TODO: a better approach is to implement a CSR @ CSC kernel.
    // NOTE: Cx arrays are expected to be contiguous!
    auto col_indices_accessor = StridedRandomAccessor<int64_t>(Cj + nnz - length, 1);
    auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
    auto kv_accessor = CompositeRandomAccessorCPU<
      decltype(col_indices_accessor), decltype(val_accessor)
    >(col_indices_accessor, val_accessor);
    std::sort(kv_accessor, kv_accessor + length, [](const auto& lhs, const auto& rhs) -> bool {
      return get<0>(lhs) < get<0>(rhs);
    });

    Cp[i + 1] = nnz;
  }
}


template <typename scalar_t>
void sparse_matmul_kernel(
    Tensor& output,
    const Tensor& mat1,
    const Tensor& mat2) {
  /*
    Computes the sparse-sparse matrix multiplication between `mat1` and `mat2`,
    which are sparse tensors in COO format.
  */
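  // The kernel proceeds in four steps, mirroring the helpers above:
  //   1. convert both COO inputs to CSR,
  //   2. run the symbolic pass (_csr_matmult_maxnnz) to size the output,
  //   3. run the numeric pass (_csr_matmult) into CSR-style output buffers,
  //   4. expand the CSR row pointer back to COO row indices (csr_to_coo).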

  auto M = mat1.size(0);
  auto N = mat2.size(1);

  const auto mat1_csr = mat1.to_sparse_csr();
  const auto mat2_csr = mat2.to_sparse_csr();

  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.crow_indices().data_ptr<int64_t>(),
      mat1_csr.crow_indices().stride(-1));
  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.col_indices().data_ptr<int64_t>(),
      mat1_csr.col_indices().stride(-1));
  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
      mat1_csr.values().data_ptr<scalar_t>(),
      mat1_csr.values().stride(-1));
  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.crow_indices().data_ptr<int64_t>(),
      mat2_csr.crow_indices().stride(-1));
  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.col_indices().data_ptr<int64_t>(),
      mat2_csr.col_indices().stride(-1));
  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
      mat2_csr.values().data_ptr<scalar_t>(),
      mat2_csr.values().stride(-1));

  const auto nnz = _csr_matmult_maxnnz(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr);

  auto output_indices = output._indices();
  auto output_values = output._values();

  Tensor output_indptr = at::empty({M + 1}, kLong);
  at::native::resize_output(output_indices, {2, nnz});
  at::native::resize_output(output_values, nnz);

  Tensor output_row_indices = output_indices.select(0, 0);
  Tensor output_col_indices = output_indices.select(0, 1);

  // TODO: replace with a CSR @ CSC kernel for better performance.
  _csr_matmult(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat1_values_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr,
      mat2_values_ptr,
      output_indptr.data_ptr<int64_t>(),
      output_col_indices.data_ptr<int64_t>(),
      output_values.data_ptr<scalar_t>());

  csr_to_coo(M, output_indptr.data_ptr<int64_t>(), output_row_indices.data_ptr<int64_t>());
  output._coalesced_(true);
}

} // end anonymous namespace

Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
  TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
  TORCH_CHECK(mat1_.dim() == 2);
  TORCH_CHECK(mat2_.dim() == 2);
  TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat1_.dense_dim(), "D values");
  TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat2_.dense_dim(), "D values");

  TORCH_CHECK(
      mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
      mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");

  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
      "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());

  auto output = at::native::empty_like(mat1_);
  output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
    sparse_matmul_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
  return output;
}
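// Note: this function serves as the CPU backend of the `_sparse_sparse_matmul`
// op (see the `_sparse_sparse_matmul_native.h` include above), which is what
// sparse COO @ sparse COO products on CPU ultimately dispatch to, e.g. via
// torch.sparse.mm when both inputs are sparse.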


} // namespace at::native