xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/SparseCsrTensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Basic functions on sparse tensors
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 
4 #include <ATen/core/Tensor.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/InitialTensorOptions.h>
7 #include <ATen/Layout.h>
8 #include <ATen/Parallel.h>
9 #include <ATen/SparseCsrTensorImpl.h>
10 #include <ATen/SparseCsrTensorUtils.h>
11 #include <ATen/SparseTensorImpl.h>
12 #include <ATen/native/LinearAlgebraUtils.h>
13 
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #include <ATen/NativeFunctions.h>
17 #else
18 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
19 #include <ATen/ops/_nnz_native.h>
20 #include <ATen/ops/_pin_memory_native.h>
21 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
22 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
23 #include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
24 #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
25 #include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
26 #include <ATen/ops/_sparse_compressed_tensor_with_dims_native.h>
27 #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
28 #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
29 #include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
30 #include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
31 #include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
32 #include <ATen/ops/_validate_sparse_bsr_tensor_args_native.h>
33 #include <ATen/ops/_validate_sparse_bsc_tensor_args_native.h>
34 #include <ATen/ops/aminmax.h>
35 #include <ATen/ops/ccol_indices_native.h>
36 #include <ATen/ops/clone_native.h>
37 #include <ATen/ops/col_indices_native.h>
38 #include <ATen/ops/copy_native.h>
39 #include <ATen/ops/crow_indices_native.h>
40 #include <ATen/ops/dense_dim_native.h>
41 #include <ATen/ops/empty.h>
42 #include <ATen/ops/empty_like_native.h>
43 #include <ATen/ops/empty_native.h>
44 #include <ATen/ops/is_pinned_native.h>
45 #include <ATen/ops/resize_as_sparse_native.h>
46 #include <ATen/ops/resize_native.h>
47 #include <ATen/ops/row_indices_native.h>
48 #include <ATen/ops/select_native.h>
49 #include <ATen/ops/select_copy.h>
50 #include <ATen/ops/select_copy_native.h>
51 #include <ATen/ops/sparse_compressed_tensor_native.h>
52 #include <ATen/ops/sparse_csr_tensor_native.h>
53 #include <ATen/ops/sparse_csc_tensor_native.h>
54 #include <ATen/ops/sparse_bsr_tensor_native.h>
55 #include <ATen/ops/sparse_bsc_tensor_native.h>
56 #include <ATen/ops/sparse_dim_native.h>
57 #include <ATen/ops/values_native.h>
58 #include <ATen/ops/_validate_compressed_sparse_indices.h>
59 #include <ATen/ops/where.h>
60 #endif
61 
62 namespace at::native {
63 
64 using namespace at::sparse_csr;
65 
66 namespace {
67 
68 bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& step) {
69   /*
70     This function solves the equation
71 
72       input == arange(start, end, step)
73 
74     for integers start, end, and step, if possible. If the solution
75     exists, returns true.
76   */
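  // For example (hypothetical inputs, assuming a 1-D integer tensor):
  //   input = [3, 5, 7, 9] -> start = 3, end = 10, step = 2, returns true
  //   input = [3, 5, 8]    -> returns false (steps 2 and 3 differ)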
77   int64_t n = input.numel();
78   if (n == 0) {
79     // a trivial solution
80     start = end = 0;
81     step = 1;
82   } else if (n == 1) {
83     // a simple solution
84     start = input[0].item<int64_t>();
85     end = start + 1;
86     step = 1;
87   } else {
88     Tensor first_last = input.slice(0, 0, n, n - 1).cpu();
89     int64_t start_candidate = first_last[0].item<int64_t>();
90     int64_t end_candidate = first_last[1].item<int64_t>() + 1;
91     if (end_candidate - start_candidate == n) {
92       // a special solution
93       start = start_candidate;
94       end = end_candidate;
95       step = 1;
96     } else {
97       // detect if general solution exists
98       Tensor possible_steps = input.slice(0, 1).sub(input.slice(0, 0, n - 1));
99       Tensor possible_step = possible_steps[0];
100       if ((possible_steps.eq(possible_step)).all().item<bool>()) {
101         start = start_candidate;
102         end = end_candidate;
103         step = possible_step.item<int64_t>();
104       } else {
105         // no solution
106         return false;
107       }
108     }
109   }
110   return true;
111 }
112 
113 } // end anonymous namespace
114 
115 /*
116   Validate the arguments to sparse compressed (CSR, CSC, BSR, and BSC)
117   tensor factory functions.
118 
119   The CSR and BSR invariants for PyTorch are outlined in
120 
121     https://pearu.github.io/csr_tensor_invariants.html
122     https://pearu.github.io/bsr_tensor_invariants.html
123 
124   which in what follows are generalized to all sparse compressed
125   formats with support for batched and dense dimensions.
126 */
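//
// As a concrete (hypothetical) example, the 2 x 3 CSR tensor
//
//   [[1, 0, 2],
//    [0, 0, 3]]
//
// is described by crow_indices = [0, 2, 3], col_indices = [0, 2, 2],
// and values = [1, 2, 3]; the worker below checks exactly this kind of
// relationship between the member tensors, generalized to batch,
// block, and dense dimensions.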
127 
128 static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) {
129   // Layout must be Sparse Compressed, 2.4
130   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{});
131 
132   const std::string layout_name = layoutToString(layout, /*upper=*/ true);
133   const std::string compressed_indices_name = compressedIndicesName(layout);
134   const std::string plain_indices_name = plainIndicesName(layout);
135   const std::string compressed_dim_name = compressedDimName(layout);
136   const std::string plain_dim_name = plainDimName(layout);
137 
138   // Layout Invariants
139 
140   // Re 3.5 and 3.6: in the case of compressed/plain indices tensors,
141   // we require contiguity on a per-batch basis, that is, the last stride
142   // of these indices must be 1. The reasoning for this is that
143   // indices tensors within a batch are "atomic" in the sense that
144   // sliced compressed/plain indices would not represent the indices
145   // of any sparse compressed tensor as the slicing would break the
146   // description of the tensor index structure.
147 
148   // 2.1
149   TORCH_CHECK(plain_indices.layout() == kStrided,
150               "expected ", plain_indices_name, " to be a strided tensor but got ", plain_indices.layout(), " tensor");
151 
152   // 2.2
153   TORCH_CHECK(compressed_indices.layout() == kStrided,
154               "expected ", compressed_indices_name, " to be a strided tensor but got ", compressed_indices.layout(), " tensor");
155 
156   const int base_ndim = 2;  // the compressed and plain (sparse) dimensions
157   const auto batch_ndim = compressed_indices.dim() - 1;
158   const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
159                            layout, "validate_sparse_compressed_tensor_args",
160                            [&] { return 0; }, [&] { return 2; });
161   const auto dense_ndim = values.dim() - batch_ndim - block_ndim - 1;
162 
163   // 2.3
164   TORCH_CHECK(values.layout() == kStrided,
165               "expected values to be a strided tensor but got ", values.layout(), " tensor");
166 
167   // 3.7 is dropped, that is, values tensor does not need to be
168   // contiguous, in general. Particular algorithms on sparse
169   // compressed tensors may require contiguity though.
170 
171   // Shape and Strides invariants
172 
173   // 3.2
174   TORCH_CHECK(
175               batch_ndim >= 0,
176               compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());
177 
178   // 3.3
179   TORCH_CHECK(
180               compressed_indices.dim() == plain_indices.dim(),
181               compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
182               compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");
183 
184   // 3.4
185   TORCH_CHECK(
186               dense_ndim >= 0,
187               "values must have dimensionality > sum of batch and block dimensionalities (=",
188               batch_ndim, " + ", block_ndim, ") but got ", values.dim());
189 
190   // 3.5
191   TORCH_CHECK(plain_indices.stride(-1) == 1,
192               "expected ", plain_indices_name, " to be a contiguous tensor per batch");
193 
194   // 3.6
195   TORCH_CHECK(compressed_indices.stride(-1) == 1,
196               "expected ", compressed_indices_name, " to be a contiguous tensor per batch");
197 
198   // 3.1
199   TORCH_CHECK(
200               static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
201               "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
202               batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());
203 
204   // For CSR/CSC formats, we define blocksize=(1, 1) so that checking
205   // the sparse compressed tensor invariants can be unified with the
206   // BSR/BSC invariants.
207   // 3.10
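  // (for BSR/BSC, values has shape
  //  (*batchsize, nnz, blocksize[0], blocksize[1], *densesize), so the
  //  block size can be read off dimensions batch_ndim + 1 and
  //  batch_ndim + 2 of values)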
208   DimVector blocksize{
209                       (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
210                       (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1),
211   };
212   TORCH_INTERNAL_ASSERT(blocksize.size() == 2 && blocksize[0] > 0 && blocksize[1] > 0);
213 
214   // All batch sizes must be the same and consistent with tensor batchsize, 3.1, 3.8, 3.9, 3.10
215   DimVector batchsize = DimVector(size.slice(0, batch_ndim));
216   DimVector compressed_indices_batchsize = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
217   DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim));
218   DimVector values_batchsize = DimVector(values.sizes().slice(0, batch_ndim));
219   const int64_t values_nnz = values.size(batch_ndim);
220   DimVector values_blocksize = DimVector(values.sizes().slice(batch_ndim + 1, block_ndim));
221   DimVector values_densesize = DimVector(values.sizes().slice(batch_ndim + 1 + block_ndim, dense_ndim));
222   TORCH_CHECK(
223       batchsize == compressed_indices_batchsize && batchsize == plain_indices_batchsize && batchsize == values_batchsize,
224       "all batch dimensions of ", compressed_indices_name," (=", compressed_indices_batchsize, "), ", plain_indices_name," (=",
225       plain_indices_batchsize, "), and values (=", values_batchsize, ") must be equal to tensor batch dimensions (=",
226       batchsize, ")");
227 
228   // A tensor consists of full blocks, 3.1
229   for (int i=0; i<block_ndim; i++) {
230       TORCH_CHECK(size[batch_ndim + i] % blocksize[i] == 0,
231                   "tensor shape[", batch_ndim + i, "] (=", size[batch_ndim + i],
232                   ") must be divisible by blocksize[", i, "] (=", blocksize[i],
233                   ") as defined by values shape");
234   }
235   const int64_t nrows = size[batch_ndim] / blocksize[0];
236   const int64_t ncols = size[batch_ndim + 1] / blocksize[1];
237   auto [compressed_dim_size, plain_dim_size] = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
238                                                                                             [&] { return std::make_tuple(nrows, ncols); },
239                                                                                             [&] { return std::make_tuple(ncols, nrows); });
240   // 3.8
241   TORCH_CHECK(
242               compressed_indices.size(-1) == compressed_dim_size + 1,
243               compressed_indices_name, ".shape[-1] must be equal to the number of ",
244               compressed_dim_name, "s + 1 (=", compressed_dim_size + 1, "), but got ", compressed_indices.size(-1));
245   // 3.9, 3.10
246   TORCH_CHECK(
247               plain_indices.size(-1) == values_nnz,
248               plain_indices_name, ".shape[-1] must be equal to nnz (=", values_nnz,
249               ") as defined by values.shape[", batch_ndim, "], but got ", plain_indices.size(-1));
250   // Type Invariants
251   auto compressed_indices_type = compressed_indices.scalar_type();
252   auto plain_indices_type = plain_indices.scalar_type();
253   // 1.1, 1.2, 1.3
254   TORCH_CHECK(
255       compressed_indices_type == plain_indices_type,
256       compressed_indices_name, " and ", plain_indices_name, " must have the same dtype, but got ",
257       compressed_indices_type, " and ", plain_indices_type, ", respectively");
258   TORCH_CHECK(
259       compressed_indices_type == kInt || compressed_indices_type == kLong,
260       compressed_indices_name, " and ", plain_indices_name, " dtype must be Int or Long, but got ",
261       compressed_indices_type);
262 
263   if (compressed_indices.is_meta()) {
264     TORCH_CHECK(values_nnz == 0, "expected nnz to be 0 for sparse ", layout_name, " meta tensor but got ", values_nnz);
265   } else {
266     // Indices invariants
267     at::_validate_compressed_sparse_indices(
268         /*is_crow = */layout == kSparseCsr || layout == kSparseBsr,
269         compressed_indices,
270         plain_indices,
271         compressed_dim_size,
272         plain_dim_size,
273         values_nnz);
274   }
275 
276   // Device Invariants
277   // 4.1
278   TORCH_CHECK(
279       values.device().type() == kCPU || values.device().type() == kCUDA || values.device().type() == kMeta,
280       "device type of values (",
281       values.device().type(),
282       ") must be CPU or CUDA or Meta");
283   // 4.2, 4.3, 4.4
284   TORCH_CHECK(
285       compressed_indices.get_device() == values.get_device(),
286       "device of ", compressed_indices_name, " (=",
287       compressed_indices.device(),
288       ") must match device of values (=",
289       values.device(),
290       ")");
291   TORCH_CHECK(
292       compressed_indices.get_device() == plain_indices.get_device(),
293       "device of ", compressed_indices_name, " (=",
294       compressed_indices.device(),
295       ") must match device of ", plain_indices_name, " (=",
296       plain_indices.device(),
297       ")");
298   TORCH_CHECK(
299       compressed_indices.is_pinned() == values.is_pinned(),
300       "memory pinning of ", compressed_indices_name, " (=",
301       compressed_indices.is_pinned(),
302       ") must match memory pinning of values (=",
303       values.is_pinned(),
304       ")");
305   TORCH_CHECK(
306       compressed_indices.is_pinned() == plain_indices.is_pinned(),
307       "memory pinning of ", compressed_indices_name, " (=",
308       compressed_indices.is_pinned(),
309       ") must match memory pinning of ", plain_indices_name, " (=",
310       plain_indices.is_pinned(),
311       ")");
312 
313   // Autograd Invariants
314   //
315   // These are internal asserts because users should not be able to
316   // create non-floating point dtype tensors with requires_grad flag
317   // set to true.
318   TORCH_INTERNAL_ASSERT(!compressed_indices.requires_grad());
319   TORCH_INTERNAL_ASSERT(!plain_indices.requires_grad());
320 }
321 
322 void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
323   _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
324 }
325 
326 void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
327   _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
328 }
329 
330 void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
331   _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc);
332 }
333 
334 void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
335   _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr);
336 }
337 
338 void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
339   _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc);
340 }
341 
342 // Construction of CSR, CSC, BSR, and BSC tensors.
343 
344 // Note: The usage of "Csr" in names like SparseCsrTensor,
345 // SparseCsrCPU, SparseCsrCUDA, and SparseCsrTensorImpl exists because
346 // of historical reasons (and ought to be removed in the future); it
347 // does not mean that the corresponding functionality is specific to
348 // the CSR layout.
349 static SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
350   // TODO: remove this comment after enabling autograd support for CSR tensor
351   // constructor.
352   // TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
353   Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; });
354   DispatchKey dispatch_key = DispatchKey::Undefined;
355 
356   switch(options.device().type()) {
357   case kCPU:
358     dispatch_key = DispatchKey::SparseCsrCPU;
359     break;
360   case kCUDA:
361     dispatch_key = DispatchKey::SparseCsrCUDA;
362     break;
363   case kMeta:
364     dispatch_key = DispatchKey::SparseCsrMeta;
365     break;
366   case kPrivateUse1:
367     dispatch_key = DispatchKey::SparseCsrPrivateUse1;
368     break;
369   default:
370     TORCH_CHECK_NOT_IMPLEMENTED(false, "Could not run 'new_compressed_tensor' from the '", options.device(), "' device.");
371   }
372 
373   return detail::make_tensor<SparseCsrTensorImpl>(DispatchKeySet(dispatch_key), options.device(), layout, options.dtype());
374 }
375 
376 Tensor sparse_compressed_tensor_with_dims(
377      int64_t nnz,
378      int64_t dense_dim,
379      c10::IntArrayRef size,
380      c10::IntArrayRef blocksize,
381      ScalarType index_dtype,
382      std::optional<ScalarType> dtype,
383      std::optional<Layout> layout,
384      std::optional<Device> device,
385      std::optional<bool> pin_memory) {
386   // sparse_compressed_tensor_with_dims is a generalization of empty
387   // that enables the specification of nnz, dense_dim, blocksize, and
388   // index_dtype for sparse compressed tensors.
389   //
390   // sparse_compressed_tensor_with_dims indices and values tensors are
391   // created as empty tensors, so the returned sparse compressed
392   // tensor will not satisfy the sparse compressed tensor
393   // invariants. The caller is responsible for initializing the
394   // indices tensors properly.
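  //
  // As a sketch (assuming CSR layout, size = (nrows, ncols), and
  // dense_dim = 0), the member tensors created below are:
  //
  //   compressed_indices: shape (nrows + 1,), dtype index_dtype, uninitialized
  //   plain_indices:      shape (nnz,),       dtype index_dtype, uninitialized
  //   values:             shape (nnz,),       dtype dtype,       uninitialized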
395   TORCH_CHECK(layout, "sparse_compressed_tensor_with_dims: expected sparse compressed tensor layout but got none");
396 
397   Layout layout_ = layout.value();
398   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_with_dims", [&]{});
399 
400   constexpr int64_t sparse_dim = 2;
401   int64_t batch_dim = size.size() - dense_dim - sparse_dim;
402   TORCH_CHECK(batch_dim >= 0, "sparse_compressed_tensor_with_dims: dimensionality must be at least dense_dim(=",
403               dense_dim, ") + sparse_dim(=", sparse_dim, "), but got ", size.size());
404 
405   TORCH_CHECK(nnz >= 0, "sparse_compressed_tensor_with_dims: nnz must be non-negative, got ", nnz);
406 
407   auto plain_indices_size = DimVector(size.slice(0, batch_dim));
408   auto compressed_indices_size = DimVector(size.slice(0, batch_dim));
409   auto values_size = DimVector(size.slice(0, batch_dim));
410 
411   plain_indices_size.push_back(nnz);
412   values_size.push_back(nnz);
413 
414   if (layout_ == kSparseBsr || layout_ == kSparseBsc) {
415     TORCH_CHECK(blocksize.size() == (size_t)sparse_dim, "sparse_compressed_tensor_with_dims: blocksize needs to be a tuple of size ",
416                 sparse_dim, ", but got ", blocksize.size());
417     auto d0 = (layout_ == kSparseBsr ? 0 : 1);
418     auto d1 = (layout_ == kSparseBsr ? 1 : 0);
419     TORCH_CHECK(blocksize[0] > 0 && blocksize[1] > 0, "sparse_compressed_tensor_with_dims: blocksize needs to be positive, but got ", blocksize);
420     auto compressed_size = size[compressedDimension(layout_, size, dense_dim)];
421     auto plain_size = size[plainDimension(layout_, size, dense_dim)];
422     TORCH_CHECK(compressed_size % blocksize[d0] == 0, "sparse_compressed_tensor_with_dims: dimension ",
423                 compressedDimension(layout_, size, dense_dim), " must be multiple of blocksize[", d0, "](=", blocksize[d0], ") but got ", compressed_size);
424     TORCH_CHECK(plain_size % blocksize[d1] == 0, "sparse_compressed_tensor_with_dims: dimension ", plainDimension(layout_, size, dense_dim),
425                 " must be multiple of blocksize[", d1, "](=", blocksize[d1], ") but got ", plain_size);
426     compressed_indices_size.push_back(compressed_size / blocksize[d0] + 1);
427     values_size.append(DimVector(blocksize));
428   } else {
429     TORCH_CHECK(blocksize.size() == 0, "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
430     compressed_indices_size.push_back(size[compressedDimension(layout_, size, dense_dim)] + 1);
431   }
432 
433   values_size.append(DimVector(size.slice(batch_dim + sparse_dim, dense_dim)));
434   TORCH_CHECK(
435       index_dtype == ScalarType::Int || index_dtype == ScalarType::Long,
436       "indices dtype must be Int or Long, but got ", index_dtype);
437 
438   TensorOptions options_ = TensorOptions().layout(Layout::Strided).device(device).pinned_memory(pin_memory);
439   auto compressed_indices = at::empty(compressed_indices_size, options_.dtype(index_dtype));
440   auto plain_indices = at::empty(plain_indices_size, options_.dtype(index_dtype));
441   auto values = at::empty(values_size, options_.dtype(dtype));
442   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
443   SparseCsrTensor self = new_compressed_tensor(options);
444   if (pin_memory.value_or(false) && !values.is_pinned()) {
445     get_sparse_csr_impl(self)->set_member_tensors(compressed_indices.pin_memory(), plain_indices.pin_memory(), values.pin_memory(), size);
446   } else {
447     get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
448   }
449   return self;
450 }
451 
452 Tensor _sparse_compressed_tensor_unsafe_symint(
453      const Tensor& compressed_indices,
454      const Tensor& plain_indices,
455      const Tensor& values,
456      c10::SymIntArrayRef size,
457      std::optional<ScalarType> dtype,
458      std::optional<Layout> layout,
459      std::optional<Device> device,
460      std::optional<bool> pin_memory) {
461   if (!layout) {
462     AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none");
463   }
464   Layout layout_ = layout.value();
465   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
466   if (at::globalContext().checkSparseTensorInvariants()) {
467     _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_);
468   }
469   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
470   SparseCsrTensor self = new_compressed_tensor(options);
471   if (pin_memory.value_or(false) && !values.is_pinned()) {
472     get_sparse_csr_impl(self)->set_member_tensors(compressed_indices.pin_memory(), plain_indices.pin_memory(), values.pin_memory(), size);
473   } else {
474     get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
475   }
476   return self;
477 }
478 
479 template <Layout required_layout>
480 Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
481                                                  const Tensor& plain_indices,
482                                                  const Tensor& values,
483                                                  IntArrayRef size,
484                                                  std::optional<ScalarType> dtype,
485                                                  std::optional<Layout> layout,
486                                                  std::optional<Device> device,
487                                                  std::optional<bool> pin_memory) {
488   Layout layout_ = layout.value_or(required_layout);
489   TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ", required_layout, " but got ", layout_);
490   if (at::globalContext().checkSparseTensorInvariants()) {
491     _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
492   }
493   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
494   SparseCsrTensor self = new_compressed_tensor(options);
495   if (pin_memory.value_or(false) && !values.is_pinned()) {
496     get_sparse_csr_impl(self)->set_member_tensors(compressed_indices.pin_memory(), plain_indices.pin_memory(), values.pin_memory(), size);
497   } else {
498     get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
499   }
500   return self;
501 }
502 
503 #define SPARSE_COMPRESSED_TENSOR_UNSAFE(KIND, REQUIRED_LAYOUT)          \
504   Tensor _sparse_##KIND##_tensor_unsafe(const Tensor& compressed_indices, \
505                                         const Tensor& plain_indices,    \
506                                         const Tensor& values,           \
507                                         IntArrayRef size,               \
508                                         std::optional<ScalarType> dtype, \
509                                         std::optional<Layout> layout,   \
510                                         std::optional<Device> device,   \
511                                         std::optional<bool> pin_memory) { \
512     return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
513   }
514 
515 SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
516 SPARSE_COMPRESSED_TENSOR_UNSAFE(csc, kSparseCsc);
517 SPARSE_COMPRESSED_TENSOR_UNSAFE(bsr, kSparseBsr);
518 SPARSE_COMPRESSED_TENSOR_UNSAFE(bsc, kSparseBsc);
519 
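// Estimate the shape of a sparse compressed tensor from its member
// tensors when no explicit size is given. For example (assuming CSR
// layout), crow_indices = [0, 2, 3] and col_indices = [0, 2, 2] give
// the estimate (len(crow_indices) - 1, max(col_indices) + 1) = (2, 3),
// extended by batch, block, and dense dimensions when present.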
520 static DimVector _estimate_sparse_compressed_tensor_size(
521     const Tensor& compressed_indices,
522     const Tensor& plain_indices,
523     const Tensor& values,
524     Layout layout) {
525   const int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size", [&] { return 0; }, [&] { return 2; });
526   const int base_ndim = 2;  // the compressed and plain (sparse) dimensions
527   const auto batch_ndim = compressed_indices.dim() - 1;
528   const std::string compressed_indices_name = compressedIndicesName(layout);
529   const std::string plain_indices_name = plainIndicesName(layout);
530   TORCH_CHECK(
531               batch_ndim >= 0,
532               compressed_indices_name, " must have dimensionality >= 1 but got ", compressed_indices.dim());
533   TORCH_CHECK(
534               compressed_indices.dim() == plain_indices.dim(),
535               compressed_indices_name, " and ", plain_indices_name, " dimensionalities must be equal but got ",
536               compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively");
537   const int64_t dense_ndim = values.dim() - batch_ndim - block_ndim - 1;
538   TORCH_CHECK(
539               dense_ndim >= 0,
540               "values must have dimensionality > sum of batch and block dimensionalities (=",
541               batch_ndim, " + ", block_ndim, ") but got ", values.dim());
542   DimVector blocksize{
543                       (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 1)) : 1),
544                       (block_ndim == 2 ? std::max<int64_t>(1, values.size(batch_ndim + 2)) : 1)
545   };
546   DimVector size = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
547   int64_t compressed_dim_size = (compressed_indices.dim() > 0 && compressed_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0);
548   int64_t plain_dim_size = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size",
549                                                       [&]() -> int64_t {
550                                                         if (plain_indices.numel() > 0) {
551                                                           return plain_indices.max().item<scalar_t>() + 1;
552                                                         } else {
553                                                           return 0;
554                                                         }
555                                                       });
556   AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size",
557       [&]{
558         size.push_back(compressed_dim_size * blocksize[0]);
559         size.push_back(plain_dim_size * blocksize[1]);
560       },
561       [&]{
562         size.push_back(plain_dim_size * blocksize[0]);
563         size.push_back(compressed_dim_size * blocksize[1]);
564       });
565   for (int i=0; i<dense_ndim; i++) {
566     int64_t j = batch_ndim + 1 + block_ndim + i;
567     size.push_back((j < values.dim() ? values.size(j) : 1));
568   }
569   TORCH_CHECK(
570               static_cast<int>(size.size()) == batch_ndim + base_ndim + dense_ndim,
571               "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=",
572               batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size());
573   return size;
574 }
575 
576 // TODO: This constructor should probably use an ATen abstract method in order
577 // to make autograd dispatch available for the CSR constructor. See the relevant
578 // note in native_functions.yaml.
579 Tensor sparse_compressed_tensor(
580     const Tensor& compressed_indices,
581     const Tensor& plain_indices,
582     const Tensor& values,
583     IntArrayRef size,
584     std::optional<ScalarType> dtype,
585     std::optional<Layout> layout,
586     std::optional<Device> device,
587     std::optional<bool> pin_memory) {
588 
589   if (!layout) {
590     AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
591   }
592   Layout layout_ = layout.value();
593   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});
594 
595   // See [Note: hacky wrapper removal for TensorOptions]
596   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
597 
598   return at::_sparse_compressed_tensor_unsafe(
599       compressed_indices,
600       plain_indices,
601       values,
602       size,
603       optTypeMetaToScalarType(options.dtype_opt()),
604       options.layout_opt(),
605       options.device_opt(),
606       options.pinned_memory_opt());
607 }
608 
609 Tensor sparse_compressed_tensor(
610     const Tensor& compressed_indices,
611     const Tensor& plain_indices,
612     const Tensor& values,
613     std::optional<ScalarType> dtype,
614     std::optional<Layout> layout,
615     std::optional<Device> device,
616     std::optional<bool> pin_memory) {
617 
618   if (!layout) {
619     AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
620   }
621   Layout layout_ = layout.value();
622   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});
623 
624   DimVector size = _estimate_sparse_compressed_tensor_size(compressed_indices, plain_indices, values, layout_);
625 
626   // See [Note: hacky wrapper removal for TensorOptions]
627   TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
628 
629   return at::_sparse_compressed_tensor_unsafe(
630       compressed_indices,
631       plain_indices,
632       values,
633       size,
634       optTypeMetaToScalarType(options.dtype_opt()),
635       options.layout_opt(),
636       options.device_opt(),
637       options.pinned_memory_opt());
638 }
639 
640 #define SPARSE_COMPRESSED_TENSOR(KIND, REQUIRED_LAYOUT)                 \
641   Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices,       \
642                                 const Tensor& plain_indices,            \
643                                 const Tensor& values,                   \
644                                 std::optional<ScalarType> dtype,        \
645                                 std::optional<Layout> layout,           \
646                                 std::optional<Device> device,           \
647                                 std::optional<bool> pin_memory) {       \
648     if (layout) {                                                       \
649       TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " # KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
650     }                                                                   \
651     std::optional<Layout> layout_(REQUIRED_LAYOUT);                     \
652     return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, dtype, layout_, device, pin_memory); \
653   }                                                                     \
654   Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices,       \
655                                 const Tensor& plain_indices,            \
656                                 const Tensor& values,                   \
657                                 IntArrayRef size,                       \
658                                 std::optional<ScalarType> dtype,        \
659                                 std::optional<Layout> layout,           \
660                                 std::optional<Device> device,           \
661                                 std::optional<bool> pin_memory) {       \
662     if (layout) {                                                       \
663       TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " # KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
664     }                                                                   \
665     std::optional<Layout> layout_(REQUIRED_LAYOUT);                     \
666     return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, size, dtype, layout_, device, pin_memory); \
667   }
668 
669 SPARSE_COMPRESSED_TENSOR(csr, kSparseCsr)
670 SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc)
671 SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
672 SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)
673 
674 Tensor empty_sparse_compressed_symint(
675     SymIntArrayRef size,
676     std::optional<ScalarType> dtype,
677     std::optional<Layout> layout,
678     std::optional<Device> device,
679     std::optional<bool> pin_memory,
680     std::optional<MemoryFormat> optional_memory_format) {
681   // TODO: Don't specialize
682   return empty_sparse_compressed(C10_AS_INTARRAYREF_SLOW_ALLOC(size), dtype, layout, device, pin_memory, optional_memory_format);
683 }
684 
685 // Warning: ideally, torch.empty(..., layout=<sparse compressed
686 // format>) ought to be unsupported because it does not return a valid
687 // sparse compressed tensor without initialization of compressed
688 // indices. The implementation below is kept for BC.
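// For instance, a hypothetical `torch.empty((2, 3), layout=torch.sparse_csr)`
// call returns a 2 x 3 CSR tensor with nnz == 0, an uninitialized
// crow_indices of shape (3,), and empty col_indices and values.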
689 Tensor empty_sparse_compressed(
690     IntArrayRef size,
691     std::optional<ScalarType> dtype,
692     std::optional<Layout> layout,
693     std::optional<Device> device,
694     std::optional<bool> pin_memory,
695     std::optional<MemoryFormat> optional_memory_format) {
696   check_size_nonnegative(size);
697   TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size ", size);
698 
699   // Strided is the default layout for torch.empty.
700   Layout layout_ = layout.value_or(Layout::Strided);
701 
702   // torch.empty cannot be used to create blocked tensors because its
703   // API lacks a method to specify the block size.
704   AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(layout_, "empty_sparse_compressed", [&]{});
705 
706   int64_t nnz = 0;
707   auto compressed_indices_size = DimVector(size.slice(0, size.size() - 2));
708   auto plain_indices_and_values_size = DimVector(size.slice(0, size.size() - 2));
709   compressed_indices_size.push_back(size[compressedDimension(layout_, size)] + 1);
710   plain_indices_and_values_size.push_back(nnz);
711 
712   TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
713   auto compressed_indices = at::empty(compressed_indices_size, options);
714   auto plain_indices = at::empty(plain_indices_and_values_size, options);
715   auto values = at::empty(plain_indices_and_values_size, options.dtype(dtype));
716   // torch.empty produces uninitialized (garbage) data, so the resulting empty
717   // sparse compressed tensor may fail to satisfy the following
718   // compressed sparse tensor invariants:
719   //
720   //   compressed_indices[..., 0] == 0
721   //   compressed_indices[..., -1] == nnz.
722   //   compressed_indices must be non-decreasing sequence
723   //
724   // Therefore, avoid using empty to create sparse compressed
725   // tensors. Instead, use compressed sparse constructors directly or
726   // other factory functions such as torch.zeros, etc.
727   return at::_sparse_compressed_tensor_unsafe(compressed_indices,
728                                               plain_indices,
729                                               values,
730                                               size,
731                                               dtype,
732                                               layout,
733                                               device,
734                                               pin_memory);
735 }
736 
737 const Tensor& resize_sparse_csr_(
738     const Tensor& self,
739     IntArrayRef size,
740     std::optional<MemoryFormat> optional_memory_format) {
741   check_size_nonnegative(size);
742   TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
743   TORCH_CHECK(
744       self.size(-1) <= size[size.size() - 1],
745       "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
746       "The original number of columns is ",
747       self.size(-1),
748       " while the requested new number of columns is ", size[size.size() - 1], ".");
749   get_sparse_csr_impl(self)->resize_(self._nnz(), size);
750   return self;
751 }
752 
753 Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocking) {
754   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{});
755   TORCH_CHECK(
756       self.layout() == src.layout(),
757       "torch.copy_: copy of sparse compressed tensors having different layouts is not supported.",
758       " self layout is ", self.layout(), " and src layout is ", src.layout());
759   TORCH_CHECK(
760       self._nnz() == src._nnz(),  // actually, values copy allows different shapes as long as operands are broadcastable
761       "torch.copy_: only sparse compressed tensors with the same number of specified elements are supported.");
762   auto self_compressed_dim = compressedDimension(self.layout(), self.sizes());
763   auto src_compressed_dim = compressedDimension(src.layout(), src.sizes());
764   auto self_compressed_dims = self.size(self_compressed_dim);
765   auto src_compressed_dims = src.size(compressedDimension(src.layout(), src.sizes()));
766   if (self_compressed_dim == src_compressed_dim) {
767     TORCH_CHECK(self_compressed_dims == src_compressed_dims,
768                 "torch.copy_: expected shapes of self and src to match along dimension ",
769                 self_compressed_dim, " for ",
770                 self.layout(), " layout but the corresponding dimensions of self and src are ",
771                 self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
772   } else {
773     TORCH_CHECK(self_compressed_dims == src_compressed_dims,
774                 "torch.copy_: expected shapes of self and src to match along dimensions ",
775                 self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ",
776                 self.layout(), " layout but the corresponding dimensions of self and src are ",
777                 self_compressed_dims, " and ", src_compressed_dims, ", respectively.");
778   }
779   AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
780                                               [&]{},
781                                               [&]{
782                                                 auto self_values = self.values();
783                                                 auto src_values = src.values();
784                                                 auto self_blocksize = DimVector(self_values.sizes().slice(self_values.dim()-2, 2));
785                                                 auto src_blocksize = DimVector(src_values.sizes().slice(src_values.dim()-2, 2));
786                                                 TORCH_CHECK(self_blocksize == src_blocksize,
787                                                             "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.",
788                                                             " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectively.");
789                                               });
790   AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_",
791                                             [&]{
792                                               self.crow_indices().copy_(src.crow_indices(), non_blocking);
793                                               self.col_indices().copy_(src.col_indices(), non_blocking);
794                                             },
795                                             [&]{
796                                               self.ccol_indices().copy_(src.ccol_indices(), non_blocking);
797                                               self.row_indices().copy_(src.row_indices(), non_blocking);
798                                             });
799   self.values().copy_(src.values(), non_blocking);
800   return self;
801 }
802 
803 // Access members of CSR tensors.
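// (each accessor below returns an alias of the corresponding member
// tensor, i.e. a new Tensor handle sharing the same storage, not a
// copy)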
804 int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
805   return get_sparse_csr_impl(self)->nnz();
806 }
807 
808 Tensor values_sparse_csr(const Tensor& self) {
809   return get_sparse_csr_impl(self)->values().alias();
810 }
811 
812 Tensor crow_indices_sparse_csr(const Tensor& self) {
813   return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
814                                                    "crow_indices",
815                                                    [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
816 }
817 
818 Tensor col_indices_sparse_csr(const Tensor& self) {
819   return AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(self.layout(),
820                                                    "col_indices",
821                                                    [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
822 }
823 
824 Tensor ccol_indices_sparse_csr(const Tensor& self) {
825   return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
826                                                    "ccol_indices",
827                                                    [&]{ return get_sparse_csr_impl(self)->compressed_indices().alias(); });
828 }
829 
830 Tensor row_indices_sparse_csr(const Tensor& self) {
831   return AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(self.layout(),
832                                                    "row_indices",
833                                                    [&]{ return get_sparse_csr_impl(self)->plain_indices().alias(); });
834 }
835 
836 Tensor crow_indices_default(const Tensor& self) {
837   TORCH_CHECK(false, "crow_indices expected sparse row compressed tensor layout but got ", self.layout());
838 }
839 
840 Tensor col_indices_default(const Tensor& self) {
841   TORCH_CHECK(false, "col_indices expected sparse row compressed tensor layout but got ", self.layout());
842 }
843 
844 Tensor ccol_indices_default(const Tensor& self) {
845   TORCH_CHECK(false, "ccol_indices expected sparse column compressed tensor layout but got ", self.layout());
846 }
847 
848 Tensor row_indices_default(const Tensor& self) {
849   TORCH_CHECK(false, "row_indices expected sparse column compressed tensor layout but got ", self.layout());
850 }
851 
852 int64_t sparse_dim_sparse_csr(const SparseCsrTensor& self) {
853   return get_sparse_csr_impl(self)->sparse_dim();
854 }
855 
856 int64_t dense_dim_sparse_csr(const SparseCsrTensor& self) {
857   return get_sparse_csr_impl(self)->dense_dim();
858 }
859 
860 const SparseCsrTensor& resize_as_sparse_compressed_(
861     const SparseCsrTensor& self,
862     const SparseCsrTensor& src) {
863   auto src_layout = src.layout();
864   auto self_layout = self.layout();
865   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
866       src_layout, "resize_as_sparse_compressed_: src ", []() {});
867   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
868       self_layout, "resize_as_sparse_compressed_: self ", []() {});
869   // Note: The impl method does all required checking to see if resize/data copy
870   // on member tensors is required.
871   get_sparse_csr_impl(self)->resize_as_sparse_compressed_tensor_(src);
872   return self;
873 }
874 
875 SparseCsrTensor clone_sparse_compressed(
876                                         const SparseCsrTensor& self,
877                                         std::optional<c10::MemoryFormat> optional_memory_format) {
878   TORCH_CHECK(
879       !optional_memory_format.has_value(),
880       "unsupported memory format option ",
881       optional_memory_format.value());
882   TensorOptions options = self.options();
883   auto compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
884                                                                       "clone_sparse_compressed",
885                                                                       [&]{ return self.crow_indices(); },
886                                                                       [&]{ return self.ccol_indices(); });
887   auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(),
888                                                                  "clone_sparse_compressed",
889                                                                  [&]{ return self.col_indices(); },
890                                                                  [&]{ return self.row_indices(); });
891   return at::_sparse_compressed_tensor_unsafe(
892        compressed_indices.clone(),
893        plain_indices.clone(),
894        self.values().clone(),
895        self.sizes(),
896        optTypeMetaToScalarType(options.dtype_opt()),
897        options.layout_opt(),
898        options.device_opt(),
899        options.pinned_memory_opt());
900 }
901 
902 Tensor empty_like_sparse_csr(
903     const Tensor& self,
904     std::optional<ScalarType> dtype,
905     std::optional<Layout> layout,
906     std::optional<Device> device,
907     std::optional<bool> pin_memory,
908     std::optional<c10::MemoryFormat> optional_memory_format) {
909   TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
910   TensorOptions options =
911       self.options()
912           .merge_in(options_)
913           .merge_memory_format(optional_memory_format);
914 
915   TORCH_CHECK(options.layout() == self.layout(),
916     "empty_like with different sparse layout is not supported (self is ",
917     self.layout(), " but you requested ", options.layout(), ")");
918   if (options.layout() == kSparseCsr) {
919     auto result = at::native::_sparse_csr_tensor_unsafe(
920         self.crow_indices().to(options.device(), self.crow_indices().dtype(), false, true),
921         self.col_indices().to(options.device(), self.col_indices().dtype(), false, true),
922         at::empty(self.values().sizes(), options.layout(kStrided)),
923         self.sizes(),
924         optTypeMetaToScalarType(options.dtype()),
925         self.layout(),
926         options.device());
927     return result;
928   } else if (options.layout() == kSparseCsc) {
929     auto result = at::native::_sparse_csc_tensor_unsafe(
930         self.ccol_indices().to(options.device(), self.ccol_indices().dtype(), false, true),
931         self.row_indices().to(options.device(), self.row_indices().dtype(), false, true),
932         at::empty(self.values().sizes(), options.layout(kStrided)),
933         self.sizes(),
934         optTypeMetaToScalarType(options.dtype()),
935         self.layout(),
936         options.device());
937     return result;
938   } else if (options.layout() == kSparseBsr) {
939     auto result = at::native::_sparse_bsr_tensor_unsafe(
940         self.crow_indices().to(options.device(), self.crow_indices().dtype(), false, true),
941         self.col_indices().to(options.device(), self.col_indices().dtype(), false, true),
942         at::empty(self.values().sizes(), options.layout(kStrided)),
943         self.sizes(),
944         optTypeMetaToScalarType(options.dtype()),
945         self.layout(),
946         options.device());
947 
948     return result;
949   } else if (options.layout() == kSparseBsc) {
950     auto result = at::native::_sparse_bsc_tensor_unsafe(
951         self.ccol_indices().to(options.device(), self.ccol_indices().dtype(), false, true),
952         self.row_indices().to(options.device(), self.row_indices().dtype(), false, true),
953         at::empty(self.values().sizes(), options.layout(kStrided)),
954         self.sizes(),
955         optTypeMetaToScalarType(options.dtype()),
956         self.layout(),
957         options.device());
958     return result;
959   } else if (options.layout() == kStrided) {
960     return at::native::empty_like(self, dtype, layout, device, pin_memory, optional_memory_format);
961   } else {
962     TORCH_CHECK(false, "Layout ", options.layout(), " is not supported");
963   }
964 }
965 
966 template <bool require_view, bool require_copy>
967 Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) {
968 #ifndef STRIP_ERROR_MESSAGES
969   constexpr const char* select_name = (require_view ? "select()" : "select_copy()");
970 #endif
971   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
972       self.layout(), "select", []() { return; });
973   TORCH_CHECK_INDEX(
974       self.dim() != 0, select_name, " cannot be applied to a 0-dim tensor.");
975   dim = maybe_wrap_dim(dim, self.dim());
976   auto size = self.size(dim);
977   if (index < -size || index >= size) {
978     TORCH_CHECK_INDEX(
979         false,
980         select_name, ": index ",
981         index,
982         " out of range for tensor of size ",
983         self.sizes(),
984         " at dimension ",
985         dim);
986   }
987   if (index < 0) {
988     index += size;
989   }
990 
991   auto select_strided = [](const Tensor& self, int64_t dim, int64_t index) {
992     if (require_copy) {
993       return at::select_copy(self, dim, index);
994     } else {
995       return self.select(dim, index);
996     }
997   };
998 
999   TORCH_INTERNAL_ASSERT(dim >= 0 && dim < self.dim());
1000 
1001   auto new_sizes = DimVector(self.sizes());
1002   new_sizes.erase(new_sizes.begin() + dim);
1003   auto options = self.options();
1004 
1005   auto [compressed_indices, plain_indices] =
1006       AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
1007           self.layout(),
1008           "select",
1009           [&]() {
1010             return std::make_pair(self.crow_indices(), self.col_indices());
1011           },
1012           [&]() {
1013             return std::make_pair(self.ccol_indices(), self.row_indices());
1014           });
1015   auto n_batch = compressed_indices.dim() - 1;
1016 
1017   if (dim < n_batch) {
1018     // Selecting batch dimension
1019     return at::_sparse_compressed_tensor_unsafe(
1020         compressed_indices.select(dim, index),
1021         plain_indices.select(dim, index),
1022         select_strided(self.values(), dim, index),
1023         new_sizes,
1024         optTypeMetaToScalarType(options.dtype_opt()),
1025         options.layout_opt(),
1026         options.device_opt(),
1027         options.pinned_memory_opt());
1028   } else if (dim < n_batch + 2) {
1029     // Selecting sparse dimension
1030     TORCH_CHECK(
1031         n_batch == 0,
1032         select_name, ": selecting sparse dimensions is not supported for batched sparse compressed tensors.");
1033     TORCH_INTERNAL_ASSERT(dim == 0 || dim == 1);
1034 
1035     DimVector blocksize{1, 1};
1036     AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&] {}, [&] {
1037       blocksize[0] = std::max<int64_t>(1, self.values().size(n_batch + 1));
1038       blocksize[1] = std::max<int64_t>(1, self.values().size(n_batch + 2));
1039     });
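    // Note: for non-block layouts (CSR/CSC) blocksize stays {1, 1}, so below
    // index / blocksize[dim] == index and index % blocksize[dim] == 0.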
1040 
1041     auto indices_options = compressed_indices.options();
1042     int64_t fast_dim = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() { return 0; }, [&]() { return 1; });
1043     int64_t other_dim = (dim == 0 ? 1 : 0);
1044     Tensor indices;
1045     Tensor values;
1046     bool is_view = dim == fast_dim;
1047     if (is_view) {
1048       // selecting along the fast (compressed) dimension is always a view operation
1049       Tensor start_end = compressed_indices.narrow(0, index / blocksize[dim], 2).cpu();
1050       int64_t start = start_end[0].item<int64_t>();
1051       int64_t end = start_end[1].item<int64_t>();
1052       indices = plain_indices.slice(0, start, end);
1053       values = self.values().slice(0, start, end);
1054     } else {
1055       Tensor decompressed_indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
1056         .select(0, 0);
1057 
1058       Tensor dim_indices = at::where(plain_indices.eq(index / blocksize[dim]))[0];
1059       // Notice that dim_indices is a sorted sequence of non-negative
1060       // distinct integers. Below we'll try to solve `dim_indices ==
1061       // arange(start, end, step)`. If a solution exists, then the
1062       // select will be a view operation also for the `dim !=
1063       // fast_dim` case.
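      // For example, dim_indices == [1 3 5] equals arange(1, 6, 2), so the
      // slices below give a view, whereas dim_indices == [0 1 3] has no
      // arange solution and index_select below must make a copy.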
1064       int64_t start{}, end{}, step{};
1065       if (solve_arange(dim_indices, start, end, step)) {
1066         indices = decompressed_indices.slice(0, start, end, step);
1067         values = self.values().slice(0, start, end, step);
1068         is_view = true;
1069       } else {
1070         // select will be a copy operation due to index_select!
1071         indices = decompressed_indices.index_select(0, dim_indices);
1072         values = self.values().index_select(0, dim_indices);
1073       }
1074     }
1075 
1076     AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() {},
1077         [&]() {
1078           /*
1079             The formulas for the select indices and values below are best
1080             explained by an example. Consider a BSR tensor with a
1081             block size (2, 3) having four blocks (the other two blocks
1082             contain all zeros and hence will not be specified):
1083 
1084               [ 1  2  3] | [ 7  8  9]
1085               [ 4  5  6] | [10 11 12]
1086               -----------------------
1087               [13 14 15] | [ 0  0  0]
1088               [16 17 18] | [ 0  0  0]
1089               -----------------------
1090               [ 0  0  0] | [19 20 21]
1091               [ 0  0  0] | [22 23 24]
1092 
1093             that represents a 6 x 6 tensor:
1094 
1095               [  1  2  3  7  8  9 ]
1096               [  4  5  6 10 11 12 ]
1097               [ 13 14 15  0  0  0 ]
1098               [ 16 17 18  0  0  0 ]
1099               [  0  0  0 19 20 21 ]
1100               [  0  0  0 22 23 24 ]
1101 
1102             The corresponding data for the BSR representation is:
1103 
1104               crow_indices = [0 2 3 4]
1105               col_indices =  [0 1 0 1]
1106               values = [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]], [[13 14 15], [16 17 18]], [[19 20 21], [22 23 24]] ]
1107               shape = (6, 6)
1108 
1109             From crow_indices, we can find that
1110 
1111               row_indices = [0 0 1 2]
1112 
1113             In the following, we'll illustrate the details of
1114             computing the result of torch.select_copy(input, dim,
1115             index) where dim is 0 or 1, and index is in
1116             range(shape[dim]).
1117 
1118             Select a row of a BSR tensor
1119             ----------------------------
1120 
1121             We first consider the dim=0 case, which corresponds to
1122             selecting the index-th row of the tensor. For instance, for
1123             dim=0 and index=1, the expected result would represent a
1124             1D tensor:
1125 
1126               [  4  5  6 10 11 12 ]
1127 
1128             that is, a concatenation of slices from the first and the
1129             second block, computed as follows:
1130 
1131               values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
1132               -> values[[0, 1]][:, 1 % 2].flatten(0, 1)
1133               -> [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]] ][:, 1].flatten(0, 1)
1134               -> [ [4 5 6], [10 11 12]].flatten(0, 1)
1135               -> [ 4 5 6 10 11 12]
1136 
1137             where dim_indices is found as
1138 
1139               where(row_indices == index//blocksize[dim])
1140               -> where([0 0 1 2] == 1//2)
1141               -> [0 1]
1142 
1143             The corresponding column indices are computed as
1144 
1145               (col_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
1146 
1147             where other_dim is 1 if dim is 0, and 0 if dim is 1. Let's
1148             expand the above expression with the data in the example:
1149 
1150               -> (col_indices[[0, 1]].mul(3).unsqueeze(1) + arange(3).unsqueeze(0)).flatten(0, 1)
1151               -> ([0 1].mul(3).unsqueeze(1) + [[0 1 2]]).flatten(0, 1)
1152               -> ([[0], [3]] + [[0 1 2]]).flatten(0, 1)     <- here addition will use broadcasting rules!
1153               -> ([[0 1 2], [3 4 5]]).flatten(0, 1)
1154               -> [0 1 2 3 4 5]
1155 
1156             Finally, the select(dim=0, index=1) op on the given sparse
1157             compressed tensor will return a COO tensor:
1158 
1159               sparse_coo_tensor([0 1 2 3 4 5].unsqueeze(0), [4 5 6 10 11 12], (6,))
1160 
1161             that represents the expected result: [ 4 5 6 10 11 12 ]
1162 
1163             Select a column of a BSR tensor
1164             -------------------------------
1165 
1166             Next, we'll consider the dim=1 case that corresponds to
1167             selecting the index-th column of the tensor. For instance,
1168             for dim=1 and index=4, the expected result would represent
1169             a 1D tensor:
1170 
1171               [  8 11 0  0 20 23]
1172 
1173             that is, a concatenation of slices from the second and the
1174             last block:
1175 
1176               values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
1177               -> values[[1, 3]][:, :, 4 % 3 ].flatten(0, 1)
1178               -> [ [[7 8 9], [10 11 12]], [[19 20 21], [22 23 24]] ][:, :, 1].flatten(0, 1)
1179               -> [ [8 11], [20 23]].flatten(0, 1)
1180               -> [ 8 11 20 23 ]
1181 
1182             The corresponding row indices are computed as
1183 
1184               (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
1185 
1186             where dim_indices is
1187 
1188               where(col_indices == index//blocksize[dim])
1189               -> where([0 1 0 1] == 4//3)
1190               -> [1 3]
1191 
1192             and we have
1193 
1194               (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
1195               -> (row_indices[[1 3]].mul(2).unsqueeze(1) + arange(2).unsqueeze(0)).flatten(0, 1)
1196               -> ([0 4].unsqueeze(1) + [0 1].unsqueeze(0)).flatten(0, 1)
1197               -> ([[0], [4]] + [[0 1]]).flatten(0, 1)     <- here addition will use broadcasting rules!
1198               -> ([[0 1], [4 5]]).flatten(0, 1)
1199               -> [ 0 1 4 5 ]
1200 
1201             Finally, the select(dim=1, index=4) op on the given sparse
1202             compressed tensor will return a COO tensor:
1203 
1204               sparse_coo_tensor([0 1 4 5].unsqueeze(0), [8 11 20 23], (6,))
1205 
1206             that represents the expected result: [ 8 11 0 0 20 23 ]
1207 
1208            */
1209           Tensor subblock_indices = at::arange(0, blocksize[other_dim], indices_options);
1210           indices = indices.mul(blocksize[other_dim]).unsqueeze(1).add(subblock_indices.unsqueeze(0)).flatten(0, 1);
1211           values = values.select(dim + 1, index % blocksize[dim]).flatten(0, 1);
1212           // flatten(0, 1) can be a view or a copy operation. If a view
1213           // is required, that is checked below via is_alias_of;
1214           // otherwise, we check here whether a copy was made, to avoid
1215           // an unnecessary clone below:
1216           if (require_copy) {
1217             is_view = values.is_alias_of(self.values());
1218           }
1219         });
1220 
1221     if (require_view) {
1222       TORCH_CHECK(values.is_alias_of(self.values()), select_name,
1223                   ": no view exists for the given input, consider using torch.select_copy.");
1224     }
1225 
1226     indices = indices.unsqueeze(0).to(kLong);
1227     if (require_copy && is_view) {
1228       values = values.clone();
1229     }
1230     return at::_sparse_coo_tensor_unsafe(indices, values, new_sizes)._coalesced_(true);
1231   } else {
1232     // Selecting dense dimension
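    // For example (no batch dims): a CSR tensor with values of shape
    // (nnz, D0, D1) maps tensor dim 2 to values dim 1, while a BSR tensor
    // with values of shape (nnz, P, Q, D0, D1) maps tensor dim 2 to values
    // dim 3.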
1233     Tensor new_values = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
1234         self.layout(),
1235         "select",
1236         // Non blocked layout (2 sparse dims become 1 nnz dim in values, so dim
1237         // is found one position to the left)
1238         [&]() { return select_strided(self.values(), dim - 1, index); },
1239         // Block layout (2 sparse dims become 1 nnz dim + 2 block-shape dims in
1240         // values, so dim is found 1 position to the right)
1241         [&]() { return select_strided(self.values(), dim + 1, index); });
1242     return at::_sparse_compressed_tensor_unsafe(
1243         compressed_indices,
1244         plain_indices,
1245         new_values,
1246         new_sizes,
1247         optTypeMetaToScalarType(options.dtype_opt()),
1248         options.layout_opt(),
1249         options.device_opt(),
1250         options.pinned_memory_opt());
1251   }
1252 }
1253 
1254 Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
1255   return select_sparse_csr_worker<true, false>(self, dim, index);
1256 }
1257 
1258 Tensor select_copy_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
1259   return select_sparse_csr_worker<false, true>(self, dim, index);
1260 }
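
/*
  Usage sketch (illustrative only, not compiled as part of this file): a
  minimal example of the view and copy paths above, assuming the public
  at::tensor, at::sparse_csr_tensor and at::select_copy factory/ops API.

    // 3x3 CSR tensor: crow_indices = [0 2 3 4], col_indices = [0 2 1 2]
    auto crow = at::tensor({0, 2, 3, 4}, at::kLong);
    auto col  = at::tensor({0, 2, 1, 2}, at::kLong);
    auto val  = at::tensor({1., 2., 3., 4.});
    auto csr  = at::sparse_csr_tensor(
        crow, col, val, {3, 3}, val.options().layout(at::kSparseCsr));

    // dim == fast_dim for CSR: slicing col_indices/values gives a view,
    // so select() succeeds and returns a sparse COO tensor of size (3,).
    auto row1 = csr.select(0, 1);

    // dim != fast_dim: select() requires that a view exists (i.e. the
    // selected positions form an arange); select_copy() always works.
    auto col2 = at::select_copy(csr, 1, 2);
*/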
1261 
1262 bool is_pinned_sparse_compressed(const Tensor& self, std::optional<Device> device) {
1263   // Assuming that compressed/plain_indices have the same pin-memory state as values
1264   return self.values().is_pinned(device);
1265 }
1266 
1267 Tensor _pin_memory_sparse_compressed(const Tensor& self, std::optional<Device> device) {
1268   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_cuda());
1269   // Pinning a sparse tensor is equivalent to cloning its indices and
1270   // values, which does not change the sparse tensor invariants. Hence,
1271   // we can skip checking the sparse tensor invariants for efficiency.
1272   CheckSparseTensorInvariants _(false);
1273   TensorOptions options = self.options().pinned_memory(true);
1274   auto impl = get_sparse_csr_impl(self);
1275   return at::_sparse_compressed_tensor_unsafe(
1276         impl->compressed_indices().pin_memory(device),
1277         impl->plain_indices().pin_memory(device),
1278         impl->values().pin_memory(device),
1279         self.sizes(),
1280         optTypeMetaToScalarType(options.dtype_opt()),
1281         options.layout_opt(),
1282         options.device_opt(),
1283         options.pinned_memory_opt());
1284 }
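
/*
  Usage sketch (illustrative only, not compiled as part of this file):
  pinning a sparse compressed tensor pins its indices and values;
  is_pinned() reports the pin state of the values buffer. Assumes a build
  with a pinned-memory allocator available (e.g. CUDA).

    auto csr = at::rand({4, 4}).to_sparse_csr();
    auto pinned = csr.pin_memory();
    // pinned.is_pinned() == true; layout, sizes and sparsity are unchanged.
*/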
1285 
1286 } // namespace at::native
1287