1 // Basic functions on sparse tensors
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3
4 #include <ATen/core/Tensor.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/InitialTensorOptions.h>
7 #include <ATen/Layout.h>
8 #include <ATen/Parallel.h>
9 #include <ATen/SparseCsrTensorImpl.h>
10 #include <ATen/SparseCsrTensorUtils.h>
11 #include <ATen/SparseTensorImpl.h>
12 #include <ATen/native/LinearAlgebraUtils.h>
13
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #include <ATen/NativeFunctions.h>
17 #else
18 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
19 #include <ATen/ops/_nnz_native.h>
20 #include <ATen/ops/_pin_memory_native.h>
21 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
22 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
23 #include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
24 #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
25 #include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
26 #include <ATen/ops/_sparse_compressed_tensor_with_dims_native.h>
27 #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
28 #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
29 #include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
30 #include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
31 #include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
32 #include <ATen/ops/_validate_sparse_bsr_tensor_args_native.h>
33 #include <ATen/ops/_validate_sparse_bsc_tensor_args_native.h>
34 #include <ATen/ops/aminmax.h>
35 #include <ATen/ops/ccol_indices_native.h>
36 #include <ATen/ops/clone_native.h>
37 #include <ATen/ops/col_indices_native.h>
38 #include <ATen/ops/copy_native.h>
39 #include <ATen/ops/crow_indices_native.h>
40 #include <ATen/ops/dense_dim_native.h>
41 #include <ATen/ops/empty.h>
42 #include <ATen/ops/empty_like_native.h>
43 #include <ATen/ops/empty_native.h>
44 #include <ATen/ops/is_pinned_native.h>
45 #include <ATen/ops/resize_as_sparse_native.h>
46 #include <ATen/ops/resize_native.h>
47 #include <ATen/ops/row_indices_native.h>
48 #include <ATen/ops/select_native.h>
49 #include <ATen/ops/select_copy.h>
50 #include <ATen/ops/select_copy_native.h>
51 #include <ATen/ops/sparse_compressed_tensor_native.h>
52 #include <ATen/ops/sparse_csr_tensor_native.h>
53 #include <ATen/ops/sparse_csc_tensor_native.h>
54 #include <ATen/ops/sparse_bsr_tensor_native.h>
55 #include <ATen/ops/sparse_bsc_tensor_native.h>
56 #include <ATen/ops/sparse_dim_native.h>
57 #include <ATen/ops/values_native.h>
58 #include <ATen/ops/_validate_compressed_sparse_indices.h>
59 #include <ATen/ops/where.h>
60 #endif
61
62 namespace at::native {
63
64 using namespace at::sparse_csr;
65
66 namespace {
67
68 bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& step) {
69 /*
70 This function solves the equation
71
72 input == arange(start, end, step)
73
74 for integers start, end, and step, if possible. If the solution
75 exists, returns true.
76 */
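// As a brief illustration (hypothetical inputs, not from the original
// code): input = [3, 5, 7] yields start = 3, end = 8, step = 2, since
// arange(3, 8, 2) == [3, 5, 7]; input = [1, 2, 4] has no solution and
// the function returns false.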
77 int64_t n = input.numel();
78 if (n == 0) {
79 // a trivial solution
80 start = end = 0;
81 step = 1;
82 } else if (n == 1) {
83 // a simple solution
84 start = input[0].item<int64_t>();
85 end = start + 1;
86 step = 1;
87 } else {
88 Tensor first_last = input.slice(0, 0, n, n - 1).cpu();
89 int64_t start_candidate = first_last[0].item<int64_t>();
90 int64_t end_candidate = first_last[1].item<int64_t>() + 1;
91 if (end_candidate - start_candidate == n) {
92 // a special solution
93 start = start_candidate;
94 end = end_candidate;
95 step = 1;
96 } else {
97 // detect if general solution exists
98 Tensor possible_steps = input.slice(0, 1).sub(input.slice(0, 0, n - 1));
99 Tensor possible_step = possible_steps[0];
100 if ((possible_steps.eq(possible_step)).all().item<bool>()) {
101 start = start_candidate;
102 end = end_candidate;
103 step = possible_step.item<int64_t>();
104 } else {
105 // no solution
106 return false;
107 }
108 }
109 }
110 return true;
111 }
112
113 } // end anonymous namespace
114
115 /*
116 Validate the arguments to sparse compressed (CSR, CSC, BSR, and BSC)
117 tensor factory functions.
118
119 The CSR and BSR invariants for PyTorch are outlined in
120
121 https://pearu.github.io/csr_tensor_invariants.html
122 https://pearu.github.io/bsr_tensor_invariants.html
123
124 which in what follows are generalized to all sparse compressed
125 formats with support for batched and dense dimensions.
126 */
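/*
  As a minimal illustration (hypothetical data, not taken from the
  original source), a 2x2 CSR tensor with two specified elements could
  be described by

    crow_indices = [0, 1, 2]   (length: number of rows + 1)
    col_indices  = [1, 0]      (length: nnz)
    values       = [10, 20]    (length: nnz)
    size         = (2, 2)

  and the checks below verify exactly such shape, dtype, device, and
  index-consistency relations, generalized to all compressed layouts
  with optional batch, block, and dense dimensions.
*/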
127
128 static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) {
129 // Layout must be Sparse Compressed, 2.4
130 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{});
131
132 const std::string layout_name = layoutToString(layout, /*upper=*/ true);
133 const std::string compressed_indices_name = compressedIndicesName(layout);
134 const std::string plain_indices_name = plainIndicesName(layout);
135 const std::string compressed_dim_name = compressedDimName(layout);
136 const std::string plain_dim_name = plainDimName(layout);
137
138 // Layout Invariants
139
140 // Re 3.5 and 3.6: in the case of compressed/plain indices tensors,
141 // we require contiguity on a per-batch basis, that is, the last
142 // stride of these indices must be 1. The reasoning for this is that
143 // indices tensors within a batch are "atomic" in the sense that
144 // sliced compressed/plain indices would not represent the indices
145 // of any sparse compressed tensor, as the slicing would break the
146 // description of the tensor index structure.
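// For example (an illustrative case, not from the original code): taking
// every other element of a valid crow_indices tensor, as in
// crow_indices[..., ::2], produces a tensor whose last stride is 2 and
// which no longer describes the row-compression of the original data,
// hence the stride-1 requirement in checks 3.5 and 3.6 below.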
147
148 // 2.1
149 TORCH_CHECK(plain_indices.layout() == kStrided,
150 "expected ", plain_indices_name, " to be a strided tensor but got ", plain_indices.layout(), " tensor");
151
152 // 2.2
153 TORCH_CHECK(compressed_indices.layout() == kStrided,
154 "expected ", compressed_indices_name, " to be a strided tensor but got ", compressed_indices.layout(), " tensor");
155
156 const int base_ndim = 2; // corresponds to compressed and plain indices
157 const auto batch_ndim = compressed_indices.dim() - 1;
158 const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
159 layout, "validate_sparse_compressed_tensor_args",
160 [&] { return 0; }, [&] { return 2; });
161 const auto dense_ndim = values.dim() - batch_ndim - block_ndim - 1;
162
163 // 2.3
164 TORCH_CHECK(values.layout() == kStrided,
165 "expected values to be a strided tensor but got ", values.layout(), " tensor");
166
167 // 3.7 is dropped, that is, values tensor does not need to be
168 // contiguous, in general. Particular algorithms on sparse
169 // compressed tensors may require contiguity though.
170
171 // Shape and Strides invariants
172
173 // 3.2
174 TORCH_CHECK(
175 batch_ndim >= 0,
176 compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());
177
178 // 3.3
179 TORCH_CHECK(
180 compressed_indices.dim() == plain_indices.dim(),
181 compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
182 compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");
183
184 // 3.4
185 TORCH_CHECK(
186 dense_ndim >= 0,
187 "values must have dimensionality > sum of batch and block dimensionalities (=",
188 batch_ndim, " + ", block_ndim, ") but got ", values.dim());
189
190 // 3.5
191 TORCH_CHECK(plain_indices.stride(-1) == 1,
192 "expected ", plain_indices_name, " to be a contiguous tensor per batch");
193
194 // 3.6
195 TORCH_CHECK(compressed_indices.stride(-1) == 1,
196 "expected ", compressed_indices_name, " to be a contiguous tensor per batch");
197
198 // 3.1
199 TORCH_CHECK(
200 static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
201 "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
202 batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());
203
204 // For CSR/CSC formats, we define blocksize=(1, 1) so that checking
205 // the sparse compressed tensor invariants can be unified with the
206 // BSR/BSC invariants.
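// For instance, an m x n CSR tensor is checked below as if it consisted
// of m x n blocks of shape (1, 1), so the same divisibility and shape
// relations apply to CSR/CSC and BSR/BSC alike.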
207 // 3.10
208 DimVector blocksize{
209 (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
210 (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1),
211 };
212 TORCH_INTERNAL_ASSERT(blocksize.size() == 2 && blocksize[0] > 0 && blocksize[1] > 0);
213
214 // All batch sizes must be the same and consistent with tensor batchsize, 3.1, 3.8, 3.9, 3.10
215 DimVector batchsize = DimVector(size.slice(0, batch_ndim));
216 DimVector compressed_indices_batchsize = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
217 DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim));
218 DimVector values_batchsize = DimVector(values.sizes().slice(0, batch_ndim));
219 const int64_t values_nnz = values.size(batch_ndim);
220 DimVector values_blocksize = DimVector(values.sizes().slice(batch_ndim + 1, block_ndim));
221 DimVector values_densesize = DimVector(values.sizes().slice(batch_ndim + 1 + block_ndim, dense_ndim));
222 TORCH_CHECK(
223 batchsize == compressed_indices_batchsize && batchsize == plain_indices_batchsize && batchsize == values_batchsize,
224 "all batch dimensions of ", compressed_indices_name," (=", compressed_indices_batchsize, "), ", plain_indices_name," (=",
225 plain_indices_batchsize, "), and values (=", values_batchsize, ") must be equal to tensor batch dimensions (=",
226 batchsize, ")");
227
228 // A tensor consists of full blocks, 3.1
229 for (int i=0; i<block_ndim; i++) {
230 TORCH_CHECK(size[batch_ndim + i] % blocksize[i] == 0,
231 "tensor shape[", batch_ndim + i, "] (=", size[batch_ndim + i],
232 ") must be divisible with blocksize[", i, "] (=", blocksize[i],
233 ") as defined by values shape");
234 }
235 const int64_t nrows = size[batch_ndim] / blocksize[0];
236 const int64_t ncols = size[batch_ndim + 1] / blocksize[1];
237 auto [compressed_dim_size, plain_dim_size] = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
238 [&] { return std::make_tuple(nrows, ncols); },
239 [&] { return std::make_tuple(ncols, nrows); });
240 // 3.8
241 TORCH_CHECK(
242 compressed_indices.size(-1) == compressed_dim_size + 1,
243 compressed_indices_name, ".shape[-1] must be equal to the number of ",
244 compressed_dim_name, "s + 1 (=", compressed_dim_size + 1, "), but got ", compressed_indices.size(-1));
245 // 3.9, 3.10
246 TORCH_CHECK(
247 plain_indices.size(-1) == values_nnz,
248 plain_indices_name, ".shape[-1] must be equal to nnz (=", values_nnz,
249 ") as defined by values.shape[", batch_ndim, "], but got ", plain_indices.size(-1));
250 // Type Invariants
251 auto compressed_indices_type = compressed_indices.scalar_type();
252 auto plain_indices_type = plain_indices.scalar_type();
253 // 1.1, 1.2, 1.3
254 TORCH_CHECK(
255 compressed_indices_type == plain_indices_type,
256 compressed_indices_name, " and ", plain_indices_name, " must have the same dtype, but got ",
257 compressed_indices_type, " and ", plain_indices_type, ", respectively");
258 TORCH_CHECK(
259 compressed_indices_type == kInt || compressed_indices_type == kLong,
260 compressed_indices_name, " and ", plain_indices_name, " dtype must be Int or Long, but got ",
261 compressed_indices_type);
262
263 if (compressed_indices.is_meta()) {
264 TORCH_CHECK(values_nnz == 0, "expected nnz to be 0 for sparse ", layout_name, " meta tensor but got ", values_nnz);
265 } else {
266 // Indices invariants
267 at::_validate_compressed_sparse_indices(
268 /*is_crow = */layout == kSparseCsr || layout == kSparseBsr,
269 compressed_indices,
270 plain_indices,
271 compressed_dim_size,
272 plain_dim_size,
273 values_nnz);
274 }
275
276 // Device Invariants
277 // 4.1
278 TORCH_CHECK(
279 values.device().type() == kCPU || values.device().type() == kCUDA || values.device().type() == kMeta,
280 "device type of values (",
281 values.device().type(),
282 ") must be CPU or CUDA or Meta");
283 // 4.2, 4.3, 4.4
284 TORCH_CHECK(
285 compressed_indices.get_device() == values.get_device(),
286 "device of ", compressed_indices_name, " (=",
287 compressed_indices.device(),
288 ") must match device of values (=",
289 values.device(),
290 ")");
291 TORCH_CHECK(
292 compressed_indices.get_device() == plain_indices.get_device(),
293 "device of ", compressed_indices_name, " (=",
294 compressed_indices.device(),
295 ") must match device of ", plain_indices_name, " (=",
296 plain_indices.device(),
297 ")");
298 TORCH_CHECK(
299 compressed_indices.is_pinned() == values.is_pinned(),
300 "memory pinning of ", compressed_indices_name, " (=",
301 compressed_indices.is_pinned(),
302 ") must match memory pinning of values (=",
303 values.is_pinned(),
304 ")");
305 TORCH_CHECK(
306 compressed_indices.is_pinned() == plain_indices.is_pinned(),
307 "memory pinning of ", compressed_indices_name, " (=",
308 compressed_indices.is_pinned(),
309 ") must match memory pinning of ", plain_indices_name, " (=",
310 plain_indices.is_pinned(),
311 ")");
312
313 // Autograd Invariants
314 //
315 // These are internal asserts because users should not be able to
316 // create non-floating point dtype tensors with requires_grad flag
317 // set to true.
318 TORCH_INTERNAL_ASSERT(!compressed_indices.requires_grad());
319 TORCH_INTERNAL_ASSERT(!plain_indices.requires_grad());
320 }
321
322 void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
323 _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
324 }
325
326 void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
327 _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
328 }
329
330 void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
331 _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc);
332 }
333
334 void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
335 _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr);
336 }
337
338 void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
339 _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc);
340 }
341
342 // Construction of CSR, CSC, BSR, and BSC tensors.
343
344 // Note: The usage of "Csr" in names like SparseCsrTensor,
345 // SparseCsrCPU, SparseCsrCUDA, and SparseCsrTensorImpl exists for
346 // historical reasons (and ought to be removed in the future) and does
347 // not mean that the corresponding functionality is specific to the
348 // CSR layout only.
349 static SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
350 // TODO: remove this comment after enabling autograd support for CSR tensor
351 // constructor.
352 // TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
353 Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; });
354 DispatchKey dispatch_key = DispatchKey::Undefined;
355
356 switch(options.device().type()) {
357 case kCPU:
358 dispatch_key = DispatchKey::SparseCsrCPU;
359 break;
360 case kCUDA:
361 dispatch_key = DispatchKey::SparseCsrCUDA;
362 break;
363 case kMeta:
364 dispatch_key = DispatchKey::SparseCsrMeta;
365 break;
366 case kPrivateUse1:
367 dispatch_key = DispatchKey::SparseCsrPrivateUse1;
368 break;
369 default:
370 TORCH_CHECK_NOT_IMPLEMENTED(false, "Could not run 'new_compressed_tensor' from the '", options.device(), "' device.");
371 }
372
373 return detail::make_tensor<SparseCsrTensorImpl>(DispatchKeySet(dispatch_key), options.device(), layout, options.dtype());
374 }
375
376 Tensor sparse_compressed_tensor_with_dims(
377 int64_t nnz,
378 int64_t dense_dim,
379 c10::IntArrayRef size,
380 c10::IntArrayRef blocksize,
381 ScalarType index_dtype,
382 std::optional<ScalarType> dtype,
383 std::optional<Layout> layout,
384 std::optional<Device> device,
385 std::optional<bool> pin_memory) {
386 // sparse_compressed_tensor_with_dims is a generalization of empty
387 // that enables the specification of nnz, dense_dim, blocksize, and
388 // index_dtype for sparse compressed tensors.
389 //
390 // The indices and values tensors of sparse_compressed_tensor_with_dims
391 // are created as empty tensors, so the returned sparse compressed
392 // tensor will not satisfy the sparse compressed tensor
393 // invariants. The caller is responsible for initializing the
394 // indices tensors properly.
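// A usage sketch under stated assumptions (hypothetical values, not
// taken from the original code): a 4 x 6 BSR tensor with 2 x 3 blocks
// and room for 3 specified blocks could be allocated via
//
//   sparse_compressed_tensor_with_dims(
//       /*nnz=*/3, /*dense_dim=*/0, /*size=*/{4, 6}, /*blocksize=*/{2, 3},
//       /*index_dtype=*/at::kInt, at::kFloat, at::kSparseBsr,
//       at::Device(at::kCPU), /*pin_memory=*/false);
//
// after which the caller must still fill in crow_indices before the
// result satisfies the sparse compressed tensor invariants.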
395 TORCH_CHECK(layout, "sparse_compressed_tensor_with_dims: expected sparse compressed tensor layout but got none");
396
397 Layout layout_ = layout.value();
398 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_with_dims", [&]{});
399
400 constexpr int64_t sparse_dim = 2;
401 int64_t batch_dim = size.size() - dense_dim - sparse_dim;
402 TORCH_CHECK(batch_dim >= 0, "sparse_compressed_tensor_with_dims: dimensionality must be at least dense_dim(=",
403 dense_dim, ") + sparse_dim(=", sparse_dim, "), but got ", size.size());
404
405 TORCH_CHECK(nnz >= 0, "sparse_compressed_tensor_with_dims: nnz must be non-negative, got ", nnz);
406
407 auto plain_indices_size = DimVector(size.slice(0, batch_dim));
408 auto compressed_indices_size = DimVector(size.slice(0, batch_dim));
409 auto values_size = DimVector(size.slice(0, batch_dim));
410
411 plain_indices_size.push_back(nnz);
412 values_size.push_back(nnz);
413
414 if (layout_ == kSparseBsr || layout_ == kSparseBsc) {
415 TORCH_CHECK(blocksize.size() == (size_t)sparse_dim, "sparse_compressed_tensor_with_dims: blocksize needs to be a tuple of size ",
416 sparse_dim, ", but got ", blocksize.size());
417 auto d0 = (layout_ == kSparseBsr ? 0 : 1);
418 auto d1 = (layout_ == kSparseBsr ? 1 : 0);
419 TORCH_CHECK(blocksize[0] > 0 && blocksize[1] > 0, "sparse_compressed_tensor_with_dims: blocksize needs to be positive, but got ", blocksize);
420 auto compressed_size = size[compressedDimension(layout_, size, dense_dim)];
421 auto plain_size = size[plainDimension(layout_, size, dense_dim)];
422 TORCH_CHECK(compressed_size % blocksize[d0] == 0, "sparse_compressed_tensor_with_dims: dimension ",
423 compressedDimension(layout_, size, dense_dim), " must be multiple of blocksize[", d0, "](=", blocksize[d0], ") but got ", compressed_size);
424 TORCH_CHECK(plain_size % blocksize[d1] == 0, "sparse_compressed_tensor_with_dims: dimension ", plainDimension(layout_, size, dense_dim),
425 " must be multiple of blocksize[", d1, "](=", blocksize[d1], ") but got ", plain_size);
426 compressed_indices_size.push_back(compressed_size / blocksize[d0] + 1);
427 values_size.append(DimVector(blocksize));
428 } else {
429 TORCH_CHECK(blocksize.size() == 0, "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
430 compressed_indices_size.push_back(size[compressedDimension(layout_, size, dense_dim)] + 1);
431 }
432
433 values_size.append(DimVector(size.slice(batch_dim + sparse_dim, dense_dim)));
434 TORCH_CHECK(
435 index_dtype == ScalarType::Int || index_dtype == ScalarType::Long,
436 "indices dtype must be Int or Long, but got ", index_dtype);
437
438 TensorOptions options_ = TensorOptions().layout(Layout::Strided).device(device).pinned_memory(pin_memory);
439 auto compressed_indices = at::empty(compressed_indices_size, options_.dtype(index_dtype));
440 auto plain_indices = at::empty(plain_indices_size, options_.dtype(index_dtype));
441 auto values = at::empty(values_size, options_.dtype(dtype));
442 TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
443 SparseCsrTensor self = new_compressed_tensor(options);
444 if (pin_memory.value_or(false) && !values.is_pinned()) {
445 get_sparse_csr_impl(self)->set_member_tensors(compressed_indices.pin_memory(), plain_indices.pin_memory(), values.pin_memory(), size);
446 } else {
447 get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
448 }
449 return self;
450 }
451
452 Tensor _sparse_compressed_tensor_unsafe_symint(
453 const Tensor& compressed_indices,
454 const Tensor& plain_indices,
455 const Tensor& values,
456 c10::SymIntArrayRef size,
457 std::optional<ScalarType> dtype,
458 std::optional<Layout> layout,
459 std::optional<Device> device,
460 std::optional<bool> pin_memory) {
461 if (!layout) {
462 AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none");
463 }
464 Layout layout_ = layout.value();
465 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
466 if (at::globalContext().checkSparseTensorInvariants()) {
467 _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_);
468 }
469 TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
470 SparseCsrTensor self = new_compressed_tensor(options);
471 if (pin_memory.value_or(false) && !values.is_pinned()) {
472 get_sparse_csr_impl(self)->set_member_tensors(compressed_indices.pin_memory(), plain_indices.pin_memory(), values.pin_memory(), size);
473 } else {
474 get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
475 }
476 return self;
477 }
478
479 template <Layout required_layout>
480 Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
481 const Tensor& plain_indices,
482 const Tensor& values,
483 IntArrayRef size,
484 std::optional<ScalarType> dtype,
485 std::optional<Layout> layout,
486 std::optional<Device> device,
487 std::optional<bool> pin_memory) {
488 Layout layout_ = layout.value_or(required_layout);
489 TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_);
490 if (at::globalContext().checkSparseTensorInvariants()) {
491 _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
492 }
493 TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
494 SparseCsrTensor self = new_compressed_tensor(options);
495 if (pin_memory.value_or(false) && !values.is_pinned()) {
496 get_sparse_csr_impl(self)->set_member_tensors(compressed_indices.pin_memory(), plain_indices.pin_memory(), values.pin_memory(), size);
497 } else {
498 get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
499 }
500 return self;
501 }
502
503 #define SPARSE_COMPRESSED_TENSOR_UNSAFE(KIND, REQUIRED_LAYOUT) \
504 Tensor _sparse_##KIND##_tensor_unsafe(const Tensor& compressed_indices, \
505 const Tensor& plain_indices, \
506 const Tensor& values, \
507 IntArrayRef size, \
508 std::optional<ScalarType> dtype, \
509 std::optional<Layout> layout, \
510 std::optional<Device> device, \
511 std::optional<bool> pin_memory) { \
512 return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
513 }
514
515 SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
516 SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc);
517 SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr);
518 SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc);
519
520 static DimVector _estimate_sparse_compressed_tensor_size(
521 const Tensor& compressed_indices,
522 const Tensor& plain_indices,
523 const Tensor& values,
524 Layout layout) {
525 const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&] { return 0; }, [&] { return 2; });
526 const int base_ndim = 2; // corresponds to compressed and plain indices
527 const auto batch_ndim = compressed_indices.dim() - 1;
528 const std::string compressed_indices_name = compressedIndicesName(layout);
529 const std::string plain_indices_name = plainIndicesName(layout);
530 TORCH_CHECK(
531 batch_ndim >= 0,
532 compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());
533 TORCH_CHECK(
534 compressed_indices.dim() == plain_indices.dim(),
535 compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
536 compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");
537 const int64_t dense_ndim = values.dim() - batch_ndim - block_ndim - 1;
538 TORCH_CHECK(
539 dense_ndim >= 0,
540 "values must have dimensionality > sum of batch and block dimensionalities (=",
541 batch_ndim, " + ", block_ndim, ") but got ", values.dim());
542 DimVector blocksize{
543 (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
544 (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1)
545 };
546 DimVector size = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
547 int64_t compressed_dim_size = (compressed_indices.dim() > 0 && compressed_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0);
548 int64_t plain_dim_size = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size",
549 [&]() -> int64_t {
550 if (plain_indices.numel() > 0) {
551 return plain_indices.max().item<scalar_t>() + 1;
552 } else {
553 return 0;
554 }
555 });
556 AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size",
557 [&]{
558 size.push_back(compressed_dim_size * blocksize[0]);
559 size.push_back(plain_dim_size * blocksize[1]);
560 },
561 [&]{
562 size.push_back(plain_dim_size * blocksize[0]);
563 size.push_back(compressed_dim_size * blocksize[1]);
564 });
565 for (int i=0; i<dense_ndim; i++) {
566 int64_t j = batch_ndim + 1 + block_ndim + i;
567 size.push_back((j < values.dim() ? values.size(j) : 1));
568 }
569 TORCH_CHECK(
570 static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
571 "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
572 batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());
573 return size;
574 }
575
576 // TODO: This constructor should probably use an ATen abstract method in order
577 // to make autograd dispatch available for the CSR constructor. See the relevant
578 // note in native_functions.yaml.
579 Tensor sparse_compressed_tensor(
580 const Tensor& compressed_indices,
581 const Tensor& plain_indices,
582 const Tensor& values,
583 IntArrayRef size,
584 std::optional<ScalarType> dtype,
585 std::optional<Layout> layout,
586 std::optional<Device> device,
587 std::optional<bool> pin_memory) {
588
589 if (!layout) {
590 AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
591 }
592 Layout layout_ = layout.value();
593 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});
594
595 // See [Note: hacky wrapper removal for TensorOptions]
596 TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
597
598 return at::_sparse_compressed_tensor_unsafe(
599 compressed_indices,
600 plain_indices,
601 values,
602 size,
603 optTypeMetaToScalarType(options.dtype_opt()),
604 options.layout_opt(),
605 options.device_opt(),
606 options.pinned_memory_opt());
607 }
608
609 Tensor sparse_compressed_tensor(
610 const Tensor& compressed_indices,
611 const Tensor& plain_indices,
612 const Tensor& values,
613 std::optional<ScalarType> dtype,
614 std::optional<Layout> layout,
615 std::optional<Device> device,
616 std::optional<bool> pin_memory) {
617
618 if (!layout) {
619 AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
620 }
621 Layout layout_ = layout.value();
622 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});
623
624 DimVector size = _estimate_sparse_compressed_tensor_size(compressed_indices, plain_indices, values, layout_);
625
626 // See [Note: hacky wrapper removal for TensorOptions]
627 TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
628
629 return at::_sparse_compressed_tensor_unsafe(
630 compressed_indices,
631 plain_indices,
632 values,
633 size,
634 optTypeMetaToScalarType(options.dtype_opt()),
635 options.layout_opt(),
636 options.device_opt(),
637 options.pinned_memory_opt());
638 }
639
640 #define SPARSE_COMPRESSED_TENSOR(KIND, REQUIRED_LAYOUT) \
641 Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices, \
642 const Tensor& plain_indices, \
643 const Tensor& values, \
644 std::optional<ScalarType> dtype, \
645 std::optional<Layout> layout, \
646 std::optional<Device> device, \
647 std::optional<bool> pin_memory) { \
648 if (layout) { \
649 TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " # KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
650 } \
651 std::optional<Layout> layout_(REQUIRED_LAYOUT); \
652 return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, dtype, layout_, device, pin_memory); \
653 } \
654 Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices, \
655 const Tensor& plain_indices, \
656 const Tensor& values, \
657 IntArrayRef size, \
658 std::optional<ScalarType> dtype, \
659 std::optional<Layout> layout, \
660 std::optional<Device> device, \
661 std::optional<bool> pin_memory) { \
662 if (layout) { \
663 TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " # KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
664 } \
665 std::optional<Layout> layout_(REQUIRED_LAYOUT); \
666 return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, size, dtype, layout_, device, pin_memory); \
667 }
668
669 SPARSE_COMPRESSED_TENSOR(csr, kSparseCsr)
670 SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc)
671 SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
672 SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)
673
674 Tensor empty_sparse_compressed_symint(
675 SymIntArrayRef size,
676 std::optional<ScalarType> dtype,
677 std::optional<Layout> layout,
678 std::optional<Device> device,
679 std::optional<bool> pin_memory,
680 std::optional<MemoryFormat> optional_memory_format) {
681 // TODO: Don't specialize
682 return empty_sparse_compressed(C10_AS_INTARRAYREF_SLOW_ALLOC(size), dtype, layout, device, pin_memory, optional_memory_format);
683 }
684
685 // Warning: ideally, torch.empty(..., layout=<sparse compressed
686 // format>) ought to be unsupported because it does not return a valid
687 // sparse compressed tensor without initialization of compressed
688 // indices. The implementation below is kept for BC.
689 Tensor empty_sparse_compressed(
690 IntArrayRef size,
691 std::optional<ScalarType> dtype,
692 std::optional<Layout> layout,
693 std::optional<Device> device,
694 std::optional<bool> pin_memory,
695 std::optional<MemoryFormat> optional_memory_format) {
696 check_size_nonnegative(size);
697 TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size ", size);
698
699 // Strided is the default layout for torch.empty.
700 Layout layout_ = layout.value_or(Layout::Strided);
701
702 // torch.empty cannot be used to create blocked tensors because its
703 // API lacks a method to specify the block size.
704 AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(layout_, "empty_sparse_compressed", [&]{});
705
706 int64_t nnz = 0;
707 auto compressed_indices_size = DimVector(size.slice(0, size.size() - 2));
708 auto plain_indices_and_values_size = DimVector(size.slice(0, size.size() - 2));
709 compressed_indices_size.push_back(size[compressedDimension(layout_, size)] + 1);
710 plain_indices_and_values_size.push_back(nnz);
711
712 TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
713 auto compressed_indices = at::empty(compressed_indices_size, options);
714 auto plain_indices = at::empty(plain_indices_and_values_size, options);
715 auto values = at::empty(plain_indices_and_values_size, options.dtype(dtype));
716 // torch.empty produces uninitialized (garbage) data, so the resulting
717 // empty sparse compressed tensor may fail to satisfy the following
718 // compressed sparse tensor invariants:
719 //
720 // compressed_indices[..., 0] == 0
721 // compressed_indices[..., -1] == nnz.
722 // compressed_indices must be a non-decreasing sequence
723 //
724 // Therefore, avoid using empty to create sparse compressed
725 // tensors. Instead, use compressed sparse constructors directly or
726 // other factory functions such as torch.zeros, etc.
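// For example (an illustrative case following from the code above),
// torch.empty((3, 4), layout=torch.sparse_csr) returns a CSR tensor with
// nnz == 0 whose crow_indices tensor has shape (4,) but arbitrary
// contents, so it is generally not usable until those indices are
// overwritten.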
727 return at::_sparse_compressed_tensor_unsafe(compressed_indices,
728 plain_indices,
729 values,
730 size,
731 dtype,
732 layout,
733 device,
734 pin_memory);
735 }
736
737 const Tensor& resize_sparse_csr_(
738 const Tensor& self,
739 IntArrayRef size,
740 std::optional<MemoryFormat> optional_memory_format) {
741 check_size_nonnegative(size);
742 TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
743 TORCH_CHECK(
744 self.size(-1) <= size[size.size() - 1],
745 "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
746 "The original number of columns is ",
747 self.size(-1),
748 " while the requested new number of columns is ", size[size.size() - 1], ".");
749 get_sparse_csr_impl(self)->resize_(self._nnz(), size);
750 return self;
751 }
752
753 Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocking) {
754 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{});
755 TORCH_CHECK(
756 self.layout() == src.layout(),
757 "torch.copy_: copy of sparse compressed tensors having different layouts is not supported.",
758 " self layout is ", self.layout(), " and src layout is ", src.layout());
759 TORCH_CHECK(
760 self._nnz() == src._nnz(), // actually, values copy allows different shapes as long as operands are broadcastable
761 "torch.copy_: only sparse compressed tensors with the same number of specified elements are supported.");
762 auto self_compressed_dim = compressedDimension(self.layout(), self.sizes());
763 auto src_compressed_dim = compressedDimension(src.layout(), src.sizes());
764 auto self_compressed_dims = self.size(self_compressed_dim);
765 auto src_compressed_dims = src.size(compressedDimension(src.layout(), src.sizes()));
766 if (self_compressed_dim == src_compressed_dim) {
767 TORCH_CHECK(self_compressed_dims == src_compressed_dims,
768 "torch.copy_: expected shapes of self and src to match along dimension ",
769 self_compressed_dim, " for ",
770 self.layout(), " layout but the corresponding dimensions of self and src are ",
771 self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
772 } else {
773 TORCH_CHECK(self_compressed_dims == src_compressed_dims,
774 "torch.copy_: expected shapes of self and src to match along dimensions ",
775 self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ",
776 self.layout(), " layout but the corresponding dimensions of self and src are ",
777 self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
778 }
779 AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
780 [&]{},
781 [&]{
782 auto self_values = self.values();
783 auto src_values = src.values();
784 auto self_blocksize = DimVector(self_values.sizes().slice(self_values.dim()-2, 2));
785 auto src_blocksize = DimVector(src_values.sizes().slice(src_values.dim()-2, 2));
786 TORCH_CHECK(self_blocksize == src_blocksize,
787 "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.",
788 " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectively.");
789 });
790 AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
791 [&]{
792 self.crow_indices().copy_(src.crow_indices(), non_blocking);
793 self.col_indices().copy_(src.col_indices(), non_blocking);
794 },
795 [&]{
796 self.ccol_indices().copy_(src.ccol_indices(), non_blocking);
797 self.row_indices().copy_(src.row_indices(), non_blocking);
798 });
799 self.values().copy_(src.values(), non_blocking);
800 return self;
801 }
802
803 // Access members of CSR tensors.
804 int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
805 return get_sparse_csr_impl(self)->nnz();
806 }
807
808 Tensor values_sparse_csr(const Tensor& self) {
809 return get_sparse_csr_impl(self)->values().alias();
810 }
811
812 Tensor crow_indices_sparse_csr(const Tensor& self) {
813 return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
814 "crow_indices",
815 [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
816 }
817
818 Tensor col_indices_sparse_csr(const Tensor& self) {
819 return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
820 "col_indices",
821 [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
822 }
823
824 Tensor ccol_indices_sparse_csr(const Tensor& self) {
825 return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
826 "ccol_indices",
827 [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
828 }
829
830 Tensor row_indices_sparse_csr(const Tensor& self) {
831 return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
832 "row_indices",
833 [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
834 }
835
836 Tensor crow_indices_default(const Tensor& self) {
837 TORCH_CHECK(false, "crow_indices expected sparse row compressed tensor layout but got ", self.layout());
838 }
839
840 Tensor col_indices_default(const Tensor& self) {
841 TORCH_CHECK(false, "col_indices expected sparse row compressed tensor layout but got ", self.layout());
842 }
843
844 Tensor ccol_indices_default(const Tensor& self) {
845 TORCH_CHECK(false, "ccol_indices expected sparse column compressed tensor layout but got ", self.layout());
846 }
847
848 Tensor row_indices_default(const Tensor& self) {
849 TORCH_CHECK(false, "row_indices expected sparse column compressed tensor layout but got ", self.layout());
850 }
851
852 int64_t sparse_dim_sparse_csr(const SparseCsrTensor& self) {
853 return get_sparse_csr_impl(self)->sparse_dim();
854 }
855
856 int64_t dense_dim_sparse_csr(const SparseCsrTensor& self) {
857 return get_sparse_csr_impl(self)->dense_dim();
858 }
859
860 const SparseCsrTensor& resize_as_sparse_compressed_(
861 const SparseCsrTensor& self,
862 const SparseCsrTensor& src) {
863 auto src_layout = src.layout();
864 auto self_layout = self.layout();
865 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
866 src_layout, "resize_as_sparse_compressed_: src ", []() {});
867 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
868 self_layout, "resize_as_sparse_compressed_: self ", []() {});
869 // Note: The impl method does all required checking to see if resize/data copy
870 // on member tensors is required.
871 get_sparse_csr_impl(self)->resize_as_sparse_compressed_tensor_(src);
872 return self;
873 }
874
875 SparseCsrTensor clone_sparse_compressed(
876 const SparseCsrTensor& self,
877 std::optional<c10::MemoryFormat> optional_memory_format) {
878 TORCH_CHECK(
879 !optional_memory_format.has_value(),
880 "unsupported memory format option ",
881 optional_memory_format.value());
882 TensorOptions options = self.options();
883 auto compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
884 "clone_sparse_compressed",
885 [&]{ return self.crow_indices(); },
886 [&]{ return self.ccol_indices(); });
887 auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
888 "clone_sparse_compressed",
889 [&]{ return self.col_indices(); },
890 [&]{ return self.row_indices(); });
891 return at::_sparse_compressed_tensor_unsafe(
892 compressed_indices.clone(),
893 plain_indices.clone(),
894 self.values().clone(),
895 self.sizes(),
896 optTypeMetaToScalarType(options.dtype_opt()),
897 options.layout_opt(),
898 options.device_opt(),
899 options.pinned_memory_opt());
900 }
901
902 Tensor empty_like_sparse_csr(
903 const Tensor& self,
904 std::optional<ScalarType> dtype,
905 std::optional<Layout> layout,
906 std::optional<Device> device,
907 std::optional<bool> pin_memory,
908 std::optional<c10::MemoryFormat> optional_memory_format) {
909 TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
910 TensorOptions options =
911 self.options()
912 .merge_in(options_)
913 .merge_memory_format(optional_memory_format);
914
915 TORCH_CHECK(options.layout() == self.layout(),
916 "empty_like with different sparse layout is not supported (self is ",
917 self.layout(), " but you requested ", options.layout(), ")");
918 if (options.layout() == kSparseCsr) {
919 auto result = at::native::_sparse_csr_tensor_unsafe(
920 self.crow_indices().to(options.device(), self.crow_indices().dtype(), false, true),
921 self.col_indices().to(options.device(), self.col_indices().dtype(), false, true),
922 at::empty(self.values().sizes(), options.layout(kStrided)),
923 self.sizes(),
924 optTypeMetaToScalarType(options.dtype()),
925 self.layout(),
926 options.device());
927 return result;
928 } else if (options.layout() == kSparseCsc) {
929 auto result = at::native::_sparse_csc_tensor_unsafe(
930 self.ccol_indices().to(options.device(), self.ccol_indices().dtype(), false, true),
931 self.row_indices().to(options.device(), self.row_indices().dtype(), false, true),
932 at::empty(self.values().sizes(), options.layout(kStrided)),
933 self.sizes(),
934 optTypeMetaToScalarType(options.dtype()),
935 self.layout(),
936 options.device());
937 return result;
938 } else if (options.layout() == kSparseBsr) {
939 auto result = at::native::_sparse_bsr_tensor_unsafe(
940 self.crow_indices().to(options.device(), self.crow_indices().dtype(), false, true),
941 self.col_indices().to(options.device(), self.col_indices().dtype(), false, true),
942 at::empty(self.values().sizes(), options.layout(kStrided)),
943 self.sizes(),
944 optTypeMetaToScalarType(options.dtype()),
945 self.layout(),
946 options.device());
947
948 return result;
949 } else if (options.layout() == kSparseBsc) {
950 auto result = at::native::_sparse_bsc_tensor_unsafe(
951 self.ccol_indices().to(options.device(), self.ccol_indices().dtype(), false, true),
952 self.row_indices().to(options.device(), self.row_indices().dtype(), false, true),
953 at::empty(self.values().sizes(), options.layout(kStrided)),
954 self.sizes(),
955 optTypeMetaToScalarType(options.dtype()),
956 self.layout(),
957 options.device());
958 return result;
959 } else if (options.layout() == kStrided) {
960 return at::native::empty_like(self, dtype, layout, device, pin_memory, optional_memory_format);
961 } else {
962 TORCH_CHECK(false, "Layout ", options.layout(), " is not supported");
963 }
964 }
965
966 template <bool require_view, bool require_copy>
967 Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) {
968 #ifndef STRIP_ERROR_MESSAGES
969 constexpr const char* select_name = (require_view ? "select()" : "select_copy()");
970 #endif
971 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
972 self.layout(), "select", []() { return; });
973 TORCH_CHECK_INDEX(
974 self.dim() != 0, select_name, " cannot be applied to a 0-dim tensor.");
975 dim = maybe_wrap_dim(dim, self.dim());
976 auto size = self.size(dim);
977 if (index < -size || index >= size) {
978 TORCH_CHECK_INDEX(
979 false,
980 select_name, ": index ",
981 index,
982 " out of range for tensor of size ",
983 self.sizes(),
984 " at dimension ",
985 dim);
986 }
987 if (index < 0) {
988 index += size;
989 }
990
991 auto select_strided = [](const Tensor& self, int64_t dim, int64_t index) {
992 if (require_copy) {
993 return at::select_copy(self, dim, index);
994 } else {
995 return self.select(dim, index);
996 }
997 };
998
999 TORCH_INTERNAL_ASSERT(dim >= 0 && dim < self.dim());
1000
1001 auto new_sizes = DimVector(self.sizes());
1002 new_sizes.erase(new_sizes.begin() + dim);
1003 auto options = self.options();
1004
1005 auto [compressed_indices, plain_indices] =
1006 AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
1007 self.layout(),
1008 "select",
1009 [&]() {
1010 return std::make_pair(self.crow_indices(), self.col_indices());
1011 },
1012 [&]() {
1013 return std::make_pair(self.ccol_indices(), self.row_indices());
1014 });
1015 auto n_batch = compressed_indices.dim() - 1;
1016
1017 if (dim < n_batch) {
1018 // Selecting batch dimension
1019 return at::_sparse_compressed_tensor_unsafe(
1020 compressed_indices.select(dim, index),
1021 plain_indices.select(dim, index),
1022 select_strided(self.values(), dim, index),
1023 new_sizes,
1024 optTypeMetaToScalarType(options.dtype_opt()),
1025 options.layout_opt(),
1026 options.device_opt(),
1027 options.pinned_memory_opt());
1028 } else if (dim < n_batch + 2) {
1029 // Selecting sparse dimension
1030 TORCH_CHECK(
1031 n_batch == 0,
1032 select_name, ": selecting sparse dimensions is not supported for batched sparse compressed tensors.")
1033 TORCH_INTERNAL_ASSERT(dim == 0 || dim == 1);
1034
1035 DimVector blocksize{1, 1};
1036 AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&] {}, [&] {
1037 blocksize[0] = std::max<int64_t>(1, self.values().size(n_batch + 1));
1038 blocksize[1] = std::max<int64_t>(1, self.values().size(n_batch + 2));
1039 });
1040
1041 auto indices_options = compressed_indices.options();
1042 int64_t fast_dim = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() { return 0; }, [&]() { return 1; });
1043 int64_t other_dim = (dim == 0 ? 1 : 0);
1044 Tensor indices;
1045 Tensor values;
1046 bool is_view = dim == fast_dim;
1047 if (is_view) {
1048 // select is always a view operation
1049 Tensor start_end = compressed_indices.narrow(0, index / blocksize[dim], 2).cpu();
1050 int64_t start = start_end[0].item<int64_t>();
1051 int64_t end = start_end[1].item<int64_t>();
1052 indices = plain_indices.slice(0, start, end);
1053 values = self.values().slice(0, start, end);
1054 } else {
1055 Tensor decompressed_indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
1056 .select(0, 0);
1057
1058 Tensor dim_indices = at::where(plain_indices.eq(index / blocksize[dim]))[0];
1059 // Notice that dim_indices is a sorted sequence of non-negative
1060 // distinct integers. Below we'll try to solve `dim_indices ==
1061 // arange(start, stop, step)`. If the solution exists then the
1062 // select will be a view operation also for the `dim !=
1063 // fast_dim` case.
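// For instance (hypothetical data), dim_indices = [1, 3] solves to
// arange(1, 4, 2), so slicing positions 1:4:2 below returns views of
// decompressed_indices and values rather than copies.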
1064 int64_t start{}, end{}, step{};
1065 if (solve_arange(dim_indices, start, end, step)) {
1066 indices = decompressed_indices.slice(0, start, end, step);
1067 values = self.values().slice(0, start, end, step);
1068 is_view = true;
1069 } else {
1070 // select will be a copy operation due to index_select!
1071 indices = decompressed_indices.index_select(0, dim_indices);
1072 values = self.values().index_select(0, dim_indices);
1073 }
1074 }
1075
1076 AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() {},
1077 [&]() {
1078 /*
1079 The formulas for the selected indices and values below are best
1080 explained by an example. Consider a BSR tensor with a
1081 block size (2, 3) having four blocks (the other two blocks
1082 contain all zeros and hence will not be specified):
1083
1084 [ 1 2 3] | [ 7 8 9]
1085 [ 4 5 6] | [10 11 12]
1086 ---------------------
1087 [13 14 15] | [ 0 0 0]
1088 [16 17 18] | [ 0 0 0]
1089 -----------------------
1090 [ 0 0 0] | [19 20 21]
1091 [ 0 0 0] | [22 23 24]
1092
1093 that represents a 6 x 6 tensor:
1094
1095 [ 1 2 3 7 8 9 ]
1096 [ 4 5 6 10 11 12 ]
1097 [ 13 14 15 0 0 0 ]
1098 [ 16 17 18 0 0 0 ]
1099 [ 0 0 0 19 20 21 ]
1100 [ 0 0 0 22 23 24 ]
1101
1102 The corresponding data for the BSR representation is:
1103
1104 crow_indices = [0 2 3 4]
1105 col_indices = [0 1 0 1]
1106 values = [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]], [[13 14 15], [16 17 18]], [[19 20 21], [22 23 24]] ]
1107 shape = (6, 6)
1108
1109 From crow_indices, we can find that
1110
1111 row_indices = [0 0 1 2]
1112
1113 In the following, we'll illustrate the details of
1114 computing the result of torch.select_copy(input, dim,
1115 index) where dim is 0 or 1, and index is in
1116 range(shape[dim]).
1117
1118 Select a row of a BSR tensor
1119 ----------------------------
1120
1121 We will consider first the dim=0 case that corresponds to
1122 selecting a index-th row of the tensor. For instance, for
1123 dim=0 and index=1, the expected result would represent a
1124 1D tensor:
1125
1126 [ 4 5 6 10 11 12 ]
1127
1128 that is a concatenated tensor of certain slices from the
1129 first and the second block that is computed as follows:
1130
1131 values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
1132 -> values[[0, 1]][:, 1 % 2].flatten(0, 1)
1133 -> [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]] ][:, 1].flatten(0, 1)
1134 -> [ [4 5 6], [10 11 12]].flatten(0, 1)
1135 -> [ 4 5 6 10 11 12]
1136
1137 where dim_indices is found as
1138
1139 where(row_indices == index//blocksize[dim])
1140 -> where([0 0 1 2] == 1//2)
1141 -> [0 1]
1142
1143 The corresponding column indices are computed as
1144
1145 (col_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
1146
1147 where other_dim is 1 if dim is 0, and 0 if dim is 1. Let's
1148 expand the above expression with the data in the example:
1149
1150 -> (col_indices[[0, 1]].mul(3).unsqueeze(1) + arange(3).unsqueeze(0)).flatten(0, 1)
1151 -> ([[0 1].mul(3).unsqueeze(1) + [[0 1 2]]).flatten(0, 1)
1152 -> ([[[0], [3]] + [[0 1 2]]).flatten(0, 1) <- here addition will use broadcasting rules!
1153 -> ([[[0 1 2], [3 4 5]]).flatten(0, 1)
1154 -> [0 1 2 3 4 5]
1155
1156 Finally, the select(dim=0, index=1) op on the given sparse
1157 compressed tensors will return a COO tensor:
1158
1159 sparse_coo_tensor([0 1 2 3 4 5].unsqueeze(0), [4 5 6 10 11 12], (6,))
1160
1161 that represents the expected result: [ 4 5 6 10 11 12 ]
1162
1163 Select a column of a BSR tensor
1164 -------------------------------
1165
1166 Next, we'll consider the dim=1 case that corresponds to
1167 selecting the index-th column of the tensor. For instance,
1168 for dim=1 and index=4, the expected result would represent
1169 a 1D tensor:
1170
1171 [ 8 11 0 0 20 23]
1172
1173 that is a concatenated tensor of certain slices from the
1174 second and the last block:
1175
1176 values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
1177 -> values[[1, 3]][:, :, 4 % 3 ].flatten(0, 1)
1178 -> [ [[7 8 9], [10 11 12]], [[19 20 21], [22 23 24]] ][:, 1, 1].flatten(0, 1)
1179 -> [ [8 11], [20 23]].flatten(0, 1)
1180 -> [ 8 11 20 23 ]
1181
1182 The corresponding row indices are computed as
1183
1184 (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
1185
1186 where dim_indices is
1187
1188 where(col_indices == index//blocksize[dim])
1189 -> where([0 1 0 1] == 4//3)
1190 -> [1 3]
1191
1192 and we have
1193
1194 (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
1195 -> (row_indices[[1 3]].mul(2).unsqueeze(1) + arange(2).unsqueeze(0)).flatten(0, 1)
1196 -> ([0 4].unsqueeze(1) + [0 1].unsqueeze(0)).flatten(0, 1)
1197 -> ([[0], [4]] + [[0 1]]).flatten(0, 1) <- here addition will use broadcasting rules!
1198 -> ([[0 1], [4 5]]).flatten(0, 1)
1199 -> [ 0 1 4 5 ]
1200
1201 Finally, the select(dim=1, index=4) op on the given sparse
1202 compressed tensors will return a COO tensor:
1203
1204 sparse_coo_tensor([0 1 4 5].unsqueeze(0), [8 11 20 23], (6,))
1205
1206 that represents the expected result: [ 8 11 0 0 20 23 ]
1207
1208 */
1209 Tensor subblock_indices = at::arange(0, blocksize[other_dim], indices_options);
1210 indices = indices.mul(blocksize[other_dim]).unsqueeze(1).add(subblock_indices.unsqueeze(0)).flatten(0, 1);
1211 values = values.select(dim + 1, index % blocksize[dim]).flatten(0, 1);
1212 // flatten(0, 1) can be a view or a copy operation. If a view
1213 // is required, this is checked below via is_alias_of;
1214 // otherwise, we check here whether a copy was made to avoid
1215 // an unnecessary clone below:
1216 if (require_copy) {
1217 is_view = values.is_alias_of(self.values());
1218 }
1219 });
1220
1221 if (require_view) {
1222 TORCH_CHECK(values.is_alias_of(self.values()), select_name,
1223 ": no view exists for the given input, consider using torch.select_copy.");
1224 }
1225
1226 indices = indices.unsqueeze(0).to(kLong);
1227 if (require_copy && is_view) {
1228 values = values.clone();
1229 }
1230 return at::_sparse_coo_tensor_unsafe(indices, values, new_sizes)._coalesced_(true);
1231 } else {
1232 // Selecting dense dimension
1233 Tensor new_values = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
1234 self.layout(),
1235 "select",
1236 // Non blocked layout (2 sparse dims become 1 nnz dim in values, so dim
1237 // is found one position to the left)
1238 [&]() { return select_strided(self.values(), dim - 1, index); },
1239 // Block layout (2 sparse dims become 1 nnz dim + 2 block-shape dims in
1240 // values, so dim is found 1 position to the right)
1241 [&]() { return select_strided(self.values(), dim + 1, index); });
1242 return at::_sparse_compressed_tensor_unsafe(
1243 compressed_indices,
1244 plain_indices,
1245 new_values,
1246 new_sizes,
1247 optTypeMetaToScalarType(options.dtype_opt()),
1248 options.layout_opt(),
1249 options.device_opt(),
1250 options.pinned_memory_opt());
1251 }
1252 }
1253
1254 Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
1255 return select_sparse_csr_worker<true, false>(self, dim, index);
1256 }
1257
1258 Tensor select_copy_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
1259 return select_sparse_csr_worker<false, true>(self, dim, index);
1260 }
1261
1262 bool is_pinned_sparse_compressed(const Tensor& self, std::optional<Device> device) {
1263 // Assuming that compressed/plain_indices have the same pin memory state as values
1264 return self.values().is_pinned(device);
1265 }
1266
1267 Tensor _pin_memory_sparse_compressed(const Tensor& self, std::optional<Device> device) {
1268 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_cuda());
1269 // Pinning a sparse tensor amounts to cloning its indices and
1270 // values, which does not change the sparse tensor invariants. Hence,
1271 // we can skip checking the sparse tensor invariants for efficiency.
1272 CheckSparseTensorInvariants _(false);
1273 TensorOptions options = self.options().pinned_memory(true);
1274 auto impl = get_sparse_csr_impl(self);
1275 return at::_sparse_compressed_tensor_unsafe(
1276 impl->compressed_indices().pin_memory(device),
1277 impl->plain_indices().pin_memory(device),
1278 impl->values().pin_memory(device),
1279 self.sizes(),
1280 optTypeMetaToScalarType(options.dtype_opt()),
1281 options.layout_opt(),
1282 options.device_opt(),
1283 options.pinned_memory_opt());
1284 }
1285
1286 } // namespace at::native
1287