1 #include <ATen/ATen.h>
2 #include <ATen/SparseTensorImpl.h>
3 #include <ATen/InitialTensorOptions.h>
4 #include <ATen/core/LegacyTypeDispatch.h>
5
6 namespace at {
7
8 namespace {
sparseTensorSetToDeviceType(DispatchKeySet key_set)9 DeviceType sparseTensorSetToDeviceType(DispatchKeySet key_set) {
10 auto k = c10::highestPriorityBackendTypeId(key_set);
11 TORCH_CHECK(c10::toFunctionalityKey(k) == DispatchKey::Sparse,
12 "cannot create sparse tensor with non sparse dispatch key ", k);
13 return c10::dispatchKeyToDeviceType(k);
14 }
15 }
16
17
// An empty dense tensor defaults to a 1-dimensional tensor of size [0]
// (recall, it is not a 0-dimensional tensor, because such a tensor would
// be a scalar and have one element)
//
// Thus, an empty sparse tensor should be a 1-dimensional tensor of size [0].
// Furthermore, we have dim == sparse_dim + dense_dim; since this is a sparse
// tensor, let us say that an empty sparse tensor has sparse_dim == 1 and
// dense_dim == 0. (There is a degree of freedom here, but given that this
// is a sparse dimension, it seems reasonable to demand that sparse_dim > 0).
//
// This means that we allocate a [1,0] size indices tensor and a [0] size
// values tensor for such an empty tensor.
SparseTensorImpl(at::DispatchKeySet key_set,const caffe2::TypeMeta data_type)30 SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type)
31 : SparseTensorImpl(key_set, data_type
32 , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(ScalarType::Long))
33 , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(data_type))) {}
34
// Delegated-to constructor: exists so the TensorImpl base can be handed the
// correct device (taken from `values`) before the members are initialized.
// Only the canonical empty shapes ([1, 0] indices, [0] values) are accepted,
// as enforced by the asserts below.
SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type, at::Tensor indices, at::Tensor values)
    : TensorImpl(key_set, data_type, values.device())
    , sparse_dim_(1)
    , indices_(std::move(indices))
    , values_(std::move(values)) {
  // we proxy to this constructor so we can initialize the device correctly, but really only indices/values of this shape are allowed.
  AT_ASSERT(indices_.sizes() == IntArrayRef({1, 0}));
  AT_ASSERT(values_.sizes() == IntArrayRef({0}));
  AT_ASSERT(values_.device() == indices_.device());
  AT_ASSERT(values_.device() == device());

  // Sparse tensors are never non-overlapping-and-dense, have no Storage
  // (any storage access must throw), and answer stride queries through the
  // custom sizes/strides policy rather than the default dense one.
  is_non_overlapping_and_dense_ = false;
  set_storage_access_should_throw();
  set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
}
50
51 // Destructor doesn't call release_resources because it's
52 // unnecessary; don't forget to change that if needed!
release_resources()53 void SparseTensorImpl::release_resources() {
54 TensorImpl::release_resources();
55 values_.reset();
56 indices_.reset();
57 }
58
set_size(int64_t dim,int64_t new_size)59 void SparseTensorImpl::set_size(int64_t dim, int64_t new_size) {
60 AT_ERROR("sparse tensors do not have set_size");
61 }
set_stride(int64_t dim,int64_t new_stride)62 void SparseTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
63 AT_ERROR("sparse tensors do not have set_stride");
64 }
set_storage_offset(int64_t storage_offset)65 void SparseTensorImpl::set_storage_offset(int64_t storage_offset) {
66 AT_ERROR("sparse tensors do not have set_storage_offset");
67 }
#ifdef DEBUG
// Debug-only override: verifies the invariant that a SparseTensorImpl never
// has a Storage set, instead of merely returning false.
bool SparseTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "SparseTensorImpl assumes that storage_ is never set");
  return false;
}
#endif
74
// Returns the name identifying this TensorImpl subclass.
const char* SparseTensorImpl::tensorimpl_type_name() const {
  return "SparseTensorImpl";
}
78
set_indices_and_values_unsafe(const Tensor & indices,const Tensor & values)79 void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) {
80 TORCH_CHECK(allow_tensor_metadata_change(), "set_indices_and_values_unsafe ", err_msg_tensor_metadata_change_not_allowed);
81
82 TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
83 TORCH_CHECK(!values.is_sparse(), "expected values to be a dense tensor, but got values of layout ", values.layout());
84
85 TORCH_CHECK(values.device().type() == device().type(), "device type of values (", values.device().type(), ") must match device type of device().type()", device().type(), ")");
86 TORCH_CHECK(values.scalar_type() == typeMetaToScalarType(dtype()), "dtype of values (", values.scalar_type(), ") must match dtype of sparse tensor (", typeMetaToScalarType(dtype()), ")");
87 TORCH_CHECK(indices.scalar_type() == kLong, "indices must be an int64 tensor");
88 TORCH_CHECK(indices.options().backend() == values.options().backend(), "backend of indices (", indices.options().backend(), ") must match backend of values (", values.options().backend(), ")");
89 TORCH_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")");
90
91 TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sym_sizes());
92 TORCH_CHECK(indices.sym_size(1) == values.sym_size(0), "indices and values must have same nnz, but got nnz from indices: ", indices.sym_size(1), ", nnz from values: ", values.sym_size(0));
93 TORCH_CHECK(indices.sym_size(0) == sparse_dim_, "indices has incorrect first dimension, expected ", sparse_dim_, ", got ", indices.sym_size(0));
94 TORCH_CHECK(values.dim() == dense_dim_ + 1, "values has incorrect number of dimensions, expected ", dense_dim_ + 1, ", got ", values.dim());
95
96 auto dense_size_original = sym_sizes().slice(sparse_dim_);
97 std::vector<c10::SymInt> expected_values_size_vec = {values.sym_size(0)};
98 expected_values_size_vec.insert(expected_values_size_vec.end(), dense_size_original.begin(), dense_size_original.end());
99 SymIntArrayRef expected_values_size(expected_values_size_vec);
100 auto new_values_size = values.sym_sizes();
101 TORCH_CHECK(
102 std::equal(expected_values_size.begin(), expected_values_size.end(), new_values_size.begin()),
103 "values has incorrect size, expected ", expected_values_size, ", got ", new_values_size
104 );
105
106 indices_ = indices;
107 values_ = values;
108 AT_ASSERT(device() == values_.device());
109 AT_ASSERT(values_.device() == indices_.device());
110
111 coalesced_ = TORCH_GUARD_SIZE_OBLIVIOUS(sym_nnz().sym_lt(2));
112 }
113
114
115 } // namespace at
116