#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/resize.h>
#endif

namespace at {
struct TORCH_API SparseTensorImpl : public TensorImpl {
  // Stored in COO format, indices + values.

  // INVARIANTS:
  // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // _indices.shape: dimensionality: 2,  shape: (sparse_dim, nnz)
  // _values.shape:  dimensionality: 1 + dense_dim.  shape: (nnz,
  // shape[sparse_dim:])

  int64_t sparse_dim_ = 0; // number of sparse dimensions
  int64_t dense_dim_ = 0; // number of dense dimensions

  Tensor indices_; // always a LongTensor
  Tensor values_;

  // A sparse tensor is 'coalesced' if every index occurs at most once in
  // the indices tensor, and the indices are in sorted order.  (This means
  // that it is very easy to convert a coalesced tensor to CSR format: you
  // need only compute CSR format indices.)
  //
  // Most math operations can only be performed on coalesced sparse tensors,
  // because many algorithms proceed by merging two sorted lists (of indices).
  bool coalesced_ = false;

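  // Illustration of the 'coalesced' invariant described above (a sketch only;
  // the concrete numbers are made up and are not part of the implementation):
  // for a 4x4 matrix with non-zeros 3.0 at (0, 2) and 4.0 at (1, 0), stored
  // with sparse_dim == 2 and dense_dim == 0, a coalesced layout is
  //
  //   indices_ = [[0, 1],    // shape (sparse_dim, nnz) == (2, 2)
  //               [2, 0]]
  //   values_  = [3.0, 4.0]  // shape (nnz,) == (2,)
  //
  // This is coalesced because the indices, flattened in row-major order
  // (0 * 4 + 2 == 2 and 1 * 4 + 0 == 4), are strictly increasing: no index
  // repeats and the columns of indices_ are sorted.
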
  // compute_numel with integer multiplication overflow check, see gh-57542
  void refresh_numel() {
    TensorImpl::safe_refresh_numel();
  }

 public:
  // Public for now...
  explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);

  void release_resources() override;

  int64_t nnz() const {
    return values_.size(0);
  }

  c10::SymInt sym_nnz() const {
    return values_.sym_size(0);
  }
  int64_t sparse_dim() const {
    return sparse_dim_;
  }
  int64_t dense_dim() const {
    return dense_dim_;
  }
  bool coalesced() const {
    return coalesced_;
  }
  Tensor indices() const {
    return indices_;
  }
  Tensor values() const {
    return values_;
  }

  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

#ifdef DEBUG
  bool has_storage() const override;
#endif

  // WARNING: This function does NOT preserve invariants of
  // sparse_dim/dense_dim with respect to indices and values
  void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "raw_resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "raw_resize_ called on tensor with symbolic shape")
    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  // NOTE: This function preserves invariants of sparse_dim/dense_dim with
  // respect to indices and values.
  //
  // NOTE: This function supports the following cases:
  // 1. When we keep the number of dense dimensions unchanged, and NOT
  // shrinking the size of any of the dense dimensions.
  // 2. When we keep the number of sparse dimensions unchanged, and NOT
  // shrinking the size of any of the sparse dimensions.
  // 3. When the sparse tensor has zero nnz, in which case we are free to
  // change the shapes of both its sparse and dense dimensions.
  //
  // This function DOESN'T support (and will throw an error) the following
  // cases:
  // 1. When we attempt to change the number of sparse dimensions on a
  // non-empty sparse tensor (such an operation will invalidate the indices
  // stored).
  // 2. When we attempt to change the number of dense dimensions on a
  // non-empty sparse tensor (such an operation will behave differently from
  // an equivalent dense tensor's resize method, and for API consistency we
  // don't support it).
  // 3. When we attempt to shrink the size of any of the dense dimensions on
  // a non-empty sparse tensor (such an operation will behave differently
  // from an equivalent dense tensor's resize method, and for API consistency
  // we don't support it).
  // 4. When we attempt to shrink the size of any of the sparse dimensions on
  // a non-empty sparse tensor (this could make some of the stored indices
  // out-of-bound and thus unsafe).
  template <typename T>
  void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_ called on tensor with symbolic shape")
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());
    if (nnz() > 0) {
      [[maybe_unused]] auto constexpr alt_options_msg =
          "You could try the following options:\n\
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
2. If you need to resize this tensor, you have the following options:\n\
    1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
    2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";

      TORCH_CHECK(
          sparse_dim == sparse_dim_,
          "changing the number of sparse dimensions (from ",
          sparse_dim_,
          " to ",
          sparse_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          dense_dim == dense_dim_,
          "changing the number of dense dimensions (from ",
          dense_dim_,
          " to ",
          dense_dim,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      bool shrinking_sparse_dims = false;
      bool shrinking_dense_dim = false;
      auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
      auto sparse_size_new = size.slice(0, sparse_dim);
      for (const auto i : c10::irange(sparse_dim)) {
        if (sparse_size_new[i] < sparse_size_original[i]) {
          shrinking_sparse_dims = true;
          break;
        }
      }
      auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
      auto dense_size_new = size.slice(sparse_dim);
      for (const auto i : c10::irange(dense_dim)) {
        if (dense_size_new[i] < dense_size_original[i]) {
          shrinking_dense_dim = true;
          break;
        }
      }

      TORCH_CHECK(
          !shrinking_sparse_dims,
          "shrinking the size of sparse dimensions (from ",
          sparse_size_original,
          " to ",
          sparse_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);

      TORCH_CHECK(
          !shrinking_dense_dim,
          "shrinking the size of dense dimensions (from ",
          dense_size_original,
          " to ",
          dense_size_new,
          ") on a non-empty sparse tensor is not supported.\n",
          alt_options_msg);
    }

    auto sizes_and_strides = generic_sizes<T>();
    const bool size_equals_sizes = std::equal(
        size.begin(),
        size.end(),
        sizes_and_strides.begin(),
        sizes_and_strides.end());
    if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
        (dense_dim != dense_dim_)) {
      auto nnz = at::symint::sizes<T>(values())[0];
      std::vector<T> values_size = {nnz};
      auto dense_size = size.slice(sparse_dim);
      values_size.insert(
          values_size.end(), dense_size.begin(), dense_size.end());
      at::symint::resize_<T>(values_, values_size);
      at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
    }

    if (!size_equals_sizes) {
      set_sizes_and_strides(size, std::vector<T>(size.size()));
    }
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;
    refresh_numel();
  }

  void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

  void resize_(
      int64_t sparse_dim,
      int64_t dense_dim,
      ArrayRef<c10::SymInt> size) {
    return _resize_(sparse_dim, dense_dim, size);
  }

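  // Illustrative sketch of the rules documented above _resize_ (hedged; it
  // assumes the public Tensor::sparse_resize_ entry point, which routes to
  // the resize_ overloads above, and `t` is a hypothetical sparse tensor with
  // sizes (4, 6), sparse_dim == 1, dense_dim == 1, and nnz > 0):
  //
  //   t.sparse_resize_({4, 8}, 1, 1);    // OK: dim counts unchanged, sizes grow
  //   t.sparse_resize_({3, 6}, 1, 1);    // error: shrinks a sparse dimension
  //   t.sparse_resize_({4, 6, 2}, 2, 1); // error: changes sparse_dim on a
  //                                      // non-empty tensor
  //
  // If instead t.nnz() == 0, all three calls succeed and indices_/values_ are
  // resized to match the new sparse_dim/dense_dim.
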
  // NOTE: this function will resize the sparse tensor and also set `indices`
  // and `values` to empty.
  void resize_and_clear_(
      int64_t sparse_dim,
      int64_t dense_dim,
      IntArrayRef size) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "resize_and_clear_ ",
        err_msg_tensor_metadata_change_not_allowed);
    TORCH_CHECK(
        !has_symbolic_sizes_strides_,
        "resize_and_clear_ called on tensor with symbolic shape")
    TORCH_CHECK(
        sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
        "number of dimensions must be sparse_dim (",
        sparse_dim,
        ") + dense_dim (",
        dense_dim,
        "), but got ",
        size.size());

    set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
    sparse_dim_ = sparse_dim;
    dense_dim_ = dense_dim;

    auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
    std::vector<int64_t> values_size = {0};
    auto dense_size = sizes().slice(sparse_dim);
    values_size.insert(
        values_size.end(), dense_size.begin(), dense_size.end());
    auto empty_values = at::empty(values_size, values().options());
    set_indices_and_values_unsafe(empty_indices, empty_values);
    refresh_numel();
  }

  void set_coalesced(bool coalesced) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_coalesced ",
        err_msg_tensor_metadata_change_not_allowed);
    coalesced_ = coalesced;
  }

  // NOTE: this function is only used internally and not exposed to Python
  // frontend
  void set_nnz_and_narrow(int64_t new_nnz) {
    TORCH_CHECK(
        allow_tensor_metadata_change(),
        "set_nnz_and_narrow ",
        err_msg_tensor_metadata_change_not_allowed);
    AT_ASSERT(new_nnz <= nnz());
    indices_ = indices_.narrow(1, 0, new_nnz);
    values_ = values_.narrow(0, 0, new_nnz);
    if (new_nnz < 2) {
      coalesced_ = true;
    }
  }

  // Takes indices and values and directly puts them into the sparse tensor,
  // no copy.  NOTE: this function is unsafe because it doesn't check whether
  // any indices are out of boundaries of `sizes`, so it should ONLY be used
  // where we know that the indices are guaranteed to be within bounds.
  // This used to be called THSTensor_(_move).
  // NB: This used to be able to avoid a refcount bump, but I was too lazy to
  // make it happen
  void set_indices_and_values_unsafe(
      const Tensor& indices,
      const Tensor& values);

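  // Illustrative sketch (hypothetical caller, not part of this header): a
  // kernel that has already validated its indices could install them
  // directly, e.g.
  //
  //   auto* impl =
  //       static_cast<SparseTensorImpl*>(result.unsafeGetTensorImpl());
  //   impl->set_indices_and_values_unsafe(checked_indices, checked_values);
  //
  // Skipping the bounds check is only sound if every column of
  // `checked_indices` lies within `result.sizes()` over the sparse
  // dimensions; otherwise later kernels may read out of bounds.
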
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const {
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    c10::impl::PyInterpreter&& interpreter = nullptr;
    if (mode_stack_len > 0 &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
    } else if (
        key_set_.has(DispatchKey::Python) &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      interpreter = pyobj_slot_.load_pyobj_interpreter();
    } else {
      // otherwise just copy the SparseTensorImpl and not the PyObject.
      auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
      copy_tensor_metadata(
          /*src_sparse_impl=*/this,
          /*dest_sparse_impl=*/impl.get(),
          /*version_counter=*/version_counter,
          /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
      impl->refresh_numel();
      return impl;
    }
    auto r = interpreter->detach(this);
    r->set_version_counter(std::forward<VariableVersion>(version_counter));
    r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    return r;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        version_counter, allow_tensor_metadata_change);
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        std::move(version_counter), allow_tensor_metadata_change);
  }

  /**
   * Shallow-copies data from another TensorImpl into this TensorImpl.
   *
   * For why this function doesn't check this TensorImpl's
   * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
   */
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
    auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
    copy_tensor_metadata(
        /*src_sparse_impl=*/sparse_impl,
        /*dest_sparse_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
    refresh_numel();
  }

 private:
  explicit SparseTensorImpl(
      at::DispatchKeySet,
      const caffe2::TypeMeta,
      at::Tensor indices,
      at::Tensor values);

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see
   * NOTE [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseTensorImpl* src_sparse_impl,
      SparseTensorImpl* dest_sparse_impl,
      c10::VariableVersion version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        std::move(version_counter),
        allow_tensor_metadata_change);

    // Sparse-specific fields
    dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
    dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
    dest_sparse_impl->indices_ = src_sparse_impl->indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
  }

  const char* tensorimpl_type_name() const override;
};

} // namespace at

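// Usage sketch (illustrative only; it assumes the public at::sparse_coo_tensor,
// _indices(), and _values() APIs and is not part of the header above):
//
//   // A (2, 3) grid of length-4 vectors stored sparsely over the first two
//   // dimensions: sparse_dim == 2, dense_dim == 1, two stored entries.
//   auto indices = at::tensor({0, 1, 2, 0}, at::kLong).reshape({2, 2});
//   auto values = at::ones({2, 4}); // shape (nnz, dense sizes...)
//   auto t = at::sparse_coo_tensor(indices, values, {2, 3, 4});
//   // t is backed by a SparseTensorImpl with sparse_dim() == 2,
//   // dense_dim() == 1, and nnz() == 2; t._indices() and t._values() return
//   // the stored indices_ and values_ tensors.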