#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/resize.h>
#endif

namespace at {
struct TORCH_API SparseTensorImpl : public TensorImpl {
  // Stored in COO format, indices + values.

  // INVARIANTS:
  // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // _indices.shape: dimensionality: 2,  shape: (sparse_dim, nnz)
  // _values.shape:  dimensionality: 1 + dense_dim.  shape: (nnz,
  // shape[sparse_dim:])
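  //
  // Illustrative example (not part of the implementation): a hybrid sparse
  // tensor of shape (2, 3, 4) with sparse_dim = 2 and dense_dim = 1 holding
  // nnz = 5 non-zeros stores
  //   _indices with shape (2, 5)  (one coordinate column per non-zero)
  //   _values  with shape (5, 4)  (one dense slice of shape (4,) per non-zero)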

  int64_t sparse_dim_ = 0; // number of sparse dimensions
  int64_t dense_dim_ = 0; // number of dense dimensions

  Tensor indices_; // always a LongTensor
  Tensor values_;

  // A sparse tensor is 'coalesced' if every index occurs at most once in
  // the indices tensor, and the indices are in sorted order.  (This means
  // that it is very easy to convert a coalesced tensor to CSR format: you
  // need only compute CSR format indices.)
  //
  // Most math operations can only be performed on coalesced sparse tensors,
  // because many algorithms proceed by merging two sorted lists (of indices).
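  //
  // Illustrative example (assuming the COO layout documented above): for a
  // 1-D sparse tensor, indices [[0, 1, 3]] with values [1., 2., 3.] are
  // coalesced, whereas indices [[1, 0, 1]] with values [1., 2., 3.] are not:
  // index 1 appears twice and the columns are unsorted. coalesce() merges
  // duplicates (summing their values) and sorts the indices.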
  bool coalesced_ = false;

  // compute_numel with integer multiplication overflow check, see gh-57542
  void refresh_numel() {
    TensorImpl::safe_refresh_numel();
  }

 public:
  // Public for now...
  explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);

  void release_resources() override;

  int64_t nnz() const {
    return values_.size(0);
  }

  c10::SymInt sym_nnz() const {
    return values_.sym_size(0);
  }
  int64_t sparse_dim() const {
    return sparse_dim_;
  }
  int64_t dense_dim() const {
    return dense_dim_;
  }
  bool coalesced() const {
    return coalesced_;
  }
  Tensor indices() const {
    return indices_;
  }
  Tensor values() const {
    return values_;
  }

  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

#ifdef DEBUG
  bool has_storage() const override;
#endif

  // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
  // with respect to indices and values
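  //
  // For example (hypothetical usage sketch): raw_resize_(2, 0, {10, 10})
  // only rewrites the size/stride metadata and the sparse/dense dim counts;
  // indices_ and values_ are left untouched, so the caller is responsible
  // for restoring the invariants documented above.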
  void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "raw_resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "raw_resize_ called on tensor with symbolic shape")
    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  // NOTE: This function preserves invariants of sparse_dim/dense_dim with
  // respect to indices and values.
  //
  // NOTE: This function supports the following cases:
  // 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
  // the size of any of the dense dimensions.
  // 2. When we keep the number of sparse dimensions unchanged, and NOT
  // shrinking the size of any of the sparse dimensions.
  // 3. When the sparse tensor has zero nnz, in which case we are free to change
  // the shapes of both its sparse and dense dimensions.
  //
  // This function DOESN'T support (and will throw an error) the following
  // cases:
  // 1. When we attempt to change the number of sparse dimensions on a non-empty
  // sparse tensor (such an operation will invalidate the indices stored).
  // 2. When we attempt to change the number of dense dimensions on a non-empty
  // sparse tensor (such an operation will behave differently from an equivalent
  // dense tensor's resize method, and for API consistency we don't support it).
  // 3. When we attempt to shrink the size of any of the dense dimensions on a
  // non-empty sparse tensor (such an operation will behave differently from an
  // equivalent dense tensor's resize method, and for API consistency we don't
  // support it).
  // 4. When we attempt to shrink the size of any of the sparse dimensions on a
  // non-empty sparse tensor (this could make some of the stored indices
  // out-of-bound and thus unsafe).
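  //
  // Illustrative sketch (hypothetical shapes): for a non-empty tensor with
  // sparse_dim = 1, dense_dim = 1 and size (4, 3),
  //   _resize_(1, 1, {4, 5})  // OK: grows a dense dimension (3 -> 5)
  //   _resize_(1, 1, {6, 3})  // OK: grows a sparse dimension (4 -> 6)
  //   _resize_(2, 0, {4, 3})  // error: changes sparse_dim/dense_dim with nnz > 0
  //   _resize_(1, 1, {4, 2})  // error: shrinks a dense dimension with nnz > 0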
  template <typename T>
  void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_ called on tensor with symbolic shape")
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());
    if (nnz() > 0) {
      [[maybe_unused]] auto constexpr alt_options_msg =
          "You could try the following options:\n\
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
2. If you need to resize this tensor, you have the following options:\n\
    1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
    2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";

      TORCH_CHECK(
          sparse_dim == sparse_dim_,
          "changing the number of sparse dimensions (from ",
          sparse_dim_,
          " to ",
          sparse_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          dense_dim == dense_dim_,
          "changing the number of dense dimensions (from ",
          dense_dim_,
          " to ",
          dense_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      bool shrinking_sparse_dims = false;
      bool shrinking_dense_dim = false;
      auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
      auto sparse_size_new = size.slice(0, sparse_dim);
      for (const auto i : c10::irange(sparse_dim)) {
        if (sparse_size_new[i] < sparse_size_original[i]) {
          shrinking_sparse_dims = true;
          break;
        }
      }
      auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
      auto dense_size_new = size.slice(sparse_dim);
      for (const auto i : c10::irange(dense_dim)) {
        if (dense_size_new[i] < dense_size_original[i]) {
          shrinking_dense_dim = true;
          break;
        }
      }

      TORCH_CHECK(
          !shrinking_sparse_dims,
          "shrinking the size of sparse dimensions (from ",
          sparse_size_original,
          " to ",
          sparse_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          !shrinking_dense_dim,
          "shrinking the size of dense dimensions (from ",
          dense_size_original,
          " to ",
          dense_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);
    }

    auto sizes_and_strides = generic_sizes<T>();
    const bool size_equals_sizes = std::equal(
        size.begin(),
        size.end(),
        sizes_and_strides.begin(),
        sizes_and_strides.end());
    if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
        (dense_dim != dense_dim_)) {
      auto nnz = at::symint::sizes<T>(values())[0];
      std::vector<T> values_size = {nnz};
      auto dense_size = size.slice(sparse_dim);
      values_size.insert(
          values_size.end(), dense_size.begin(), dense_size.end());
      at::symint::resize_<T>(values_, values_size);
      at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
    }

    if (!size_equals_sizes) {
      set_sizes_and_strides(size, std::vector<T>(size.size()));
    }
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

  void resize_(
      int64_t sparse_dim,
      int64_t dense_dim,
      ArrayRef<c10::SymInt> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

  // NOTE: this function will resize the sparse tensor and also set `indices`
  // and `values` to empty.
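  //
  // Illustrative post-conditions (derived from the body below): after
  // resize_and_clear_(sparse_dim, dense_dim, size), nnz() == 0, indices()
  // has shape (sparse_dim, 0), and values() has shape (0, size[sparse_dim:]).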
  void resize_and_clear_(
      int64_t sparse_dim,
      int64_t dense_dim,
      IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_and_clear_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_and_clear_ called on tensor with symbolic shape")
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());

    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;

    auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
    std::vector<int64_t> values_size = {0};
    auto dense_size = sizes().slice(sparse_dim);
    values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
    auto empty_values = at::empty(values_size, values().options());
    set_indices_and_values_unsafe(empty_indices, empty_values);
    refresh_numel();
  }

  void set_coalesced(bool coalesced) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_coalesced ",
        err_msg_tensor_metadata_change_not_allowed);
    coalesced_ = coalesced;
  }

  // NOTE: this function is only used internally and not exposed to Python
  // frontend
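  //
  // For example (illustrative only): set_nnz_and_narrow(3) keeps the first
  // 3 columns of indices_ and the first 3 rows of values_; with fewer than
  // 2 remaining entries the tensor is trivially coalesced.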
  void set_nnz_and_narrow(int64_t new_nnz) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_nnz_and_narrow ",
        err_msg_tensor_metadata_change_not_allowed);
    AT_ASSERT(new_nnz <= nnz());
    indices_ = indices_.narrow(1, 0, new_nnz);
    values_ = values_.narrow(0, 0, new_nnz);
    if (new_nnz < 2) {
      coalesced_ = true;
    }
  }

  // Takes indices and values and directly puts them into the sparse tensor, no
  // copy. NOTE: this function is unsafe because it doesn't check whether any
  // indices are out of boundaries of `sizes`, so it should ONLY be used where
  // we know that the indices are guaranteed to be within bounds.
  // This used to be called THSTensor_(_move).
  // NB: This used to be able to avoid a refcount bump, but I was too lazy to
  // make it happen.
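  //
  // Usage sketch (hypothetical caller, shapes follow the invariants above):
  // for a tensor of size (3, 3) with sparse_dim = 2 and dense_dim = 0,
  // `indices` should be a 2 x nnz LongTensor whose entries all lie in
  // [0, 3), and `values` a 1-D tensor of length nnz; nothing here validates
  // that, which is why the function is "unsafe".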
  void set_indices_and_values_unsafe(
      const Tensor& indices,
      const Tensor& values);

  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const {
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    c10::impl::PyInterpreter&& interpreter = nullptr;
    if (mode_stack_len > 0 &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
    } else if (
        key_set_.has(DispatchKey::Python) &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      interpreter = pyobj_slot_.load_pyobj_interpreter();
    } else {
      // otherwise just copy the SparseTensorImpl and not the PyObject.
      auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
      copy_tensor_metadata(
          /*src_sparse_impl=*/this,
          /*dest_sparse_impl=*/impl.get(),
          /*version_counter=*/version_counter,
          /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
      impl->refresh_numel();
      return impl;
    }
    auto r = interpreter->detach(this);
    r->set_version_counter(std::forward<VariableVersion>(version_counter));
    r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    return r;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        version_counter, allow_tensor_metadata_change);
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        std::move(version_counter), allow_tensor_metadata_change);
  }

  /**
   * Shallow-copies data from another TensorImpl into this TensorImpl.
   *
   * For why this function doesn't check this TensorImpl's
   * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
    auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
    copy_tensor_metadata(
        /*src_sparse_impl=*/sparse_impl,
        /*dest_sparse_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
  }

 private:
  explicit SparseTensorImpl(
      at::DispatchKeySet,
      const caffe2::TypeMeta,
      at::Tensor indices,
      at::Tensor values);

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
   * [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseTensorImpl* src_sparse_impl,
      SparseTensorImpl* dest_sparse_impl,
      c10::VariableVersion version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        std::move(version_counter),
        allow_tensor_metadata_change);

    // Sparse-specific fields
    dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
    dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
    dest_sparse_impl->indices_ = src_sparse_impl->indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
  }

  const char* tensorimpl_type_name() const override;
};

} // namespace at