#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Exception.h>

namespace at {

// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
// denoting the data: `crow_indices_`, `col_indices_` and `values_`.
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
// that represents the compressed row indices of the CSR tensor. The
// `col_indices_` tensor is an integer tensor of shape `(nnz())` that
// explicitly stores the column index of each value of the sparse tensor.
// The `values_` tensor can be of any PyTorch-supported data type and has
// shape `(nnz())`.
//
// Since the main advantage of the CSR format over the COO format is speed
// of computation, care must be taken to facilitate smooth interfacing of
// these data structures with optimized libraries such as MKL and MAGMA.
// Since the MKL interface for PyTorch currently uses int32 indexing, it is
// important to make sure that `crow_indices` and `col_indices` are of type
// int32 when calling MKL routines such as SPMM or SPMV.
//
// If not calling MKL, it is fine to use 64-bit integer tensors for
// indexing.
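//
// For example, the 2 x 3 matrix
//
//   [[0, 1, 0],
//    [2, 0, 3]]
//
// is stored in CSR form as
//
//   crow_indices_ = [0, 1, 3]  // row i's values occupy [crow[i], crow[i+1])
//   col_indices_  = [1, 0, 2]
//   values_       = [1, 2, 3]
//
// so `nnz()` is 3 and `crow_indices_` has `size(0) + 1 == 3` entries.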
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
  Tensor crow_indices_;
  Tensor col_indices_;
  Tensor values_;
  Layout layout_;

 public:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet,
      at::Device device,
      Layout layout,
      const caffe2::TypeMeta);

  void resize_(int64_t nnz, IntArrayRef size);
  void resize_and_clear_(
      int64_t sparse_dim,
      int64_t dense_dim,
      IntArrayRef size);
  void resize_as_sparse_compressed_tensor_(const Tensor& src);
  void set_member_tensors(
      const Tensor& crow_indices,
      const Tensor& col_indices,
      const Tensor& values,
      c10::SymIntArrayRef size);
  void set_member_tensors(
      const Tensor& crow_indices,
      const Tensor& col_indices,
      const Tensor& values,
      IntArrayRef size);

  const Tensor& compressed_indices() const {
    return crow_indices_;
  }
  const Tensor& plain_indices() const {
    return col_indices_;
  }
  const Tensor& values() const {
    return values_;
  }
  int64_t nnz() {
    return col_indices_.size(-1);
  }

  inline int64_t batch_dim() const noexcept {
    return crow_indices_.dim() - 1;
  }

  inline int64_t sparse_dim() const noexcept {
    return 2;
  }

  inline int64_t dense_dim() const noexcept {
    return values_.dim() - batch_dim() - block_dim() - 1;
  }

 private:
  inline int64_t block_dim() const noexcept {
    return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
  }
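
  // Worked example of the dimension bookkeeping above (illustrative
  // shapes): a batched BSR tensor whose `crow_indices_` has shape
  // (B, nrowblocks + 1) and whose `values_` has shape (B, nnz, 2, 2) gives
  //   batch_dim() == crow_indices_.dim() - 1 == 1
  //   block_dim() == 2   // BSR/BSC layouts store 2-D blocks
  //   dense_dim() == values_.dim() - batch_dim() - block_dim() - 1 == 0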

 protected:
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;

 public:
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
  Layout layout_impl() const override {
    return layout_;
  }
  void set_layout(Layout layout) {
    switch (layout) {
      case kSparseCsr:
      case kSparseCsc:
      case kSparseBsr:
      case kSparseBsc:
        layout_ = layout;
        break;
      default:
        TORCH_CHECK(false, "unsupported layout ", layout);
    }
  }

  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const {
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    c10::impl::PyInterpreter* interpreter = nullptr;
    if (mode_stack_len > 0 &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
    } else if (
        key_set_.has(DispatchKey::Python) &&
        !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
      interpreter = pyobj_slot_.load_pyobj_interpreter();
    } else {
      // otherwise just copy the SparseCsrTensorImpl and not the PyObject.
      auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
          key_set(), device(), layout_impl(), dtype());
      copy_tensor_metadata(
          /*src_sparse_impl=*/this,
          /*dest_sparse_impl=*/impl.get(),
          /*version_counter=*/version_counter,
          /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
      impl->refresh_numel();
      return impl;
    }
    auto r = interpreter->detach(this);
    r->set_version_counter(std::forward<VariableVersion>(version_counter));
    r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    return r;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        version_counter, allow_tensor_metadata_change);
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    return shallow_copy_and_detach_core(
        std::move(version_counter), allow_tensor_metadata_change);
  }

 private:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor crow_indices,
      at::Tensor col_indices,
      at::Tensor values,
      at::Layout layout);

  const char* tensorimpl_type_name() const override;

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see
   * NOTE [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseCsrTensorImpl* src_sparse_impl,
      SparseCsrTensorImpl* dest_sparse_impl,
      c10::VariableVersion version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        std::move(version_counter),
        allow_tensor_metadata_change);

    // Sparse-specific fields
    dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
    dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
  }
};
} // namespace at
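
// Usage sketch, assuming the generated ATen factory `at::sparse_csr_tensor`
// and the `at::tensor` literal helpers from ATen/Functions.h, and that
// `Tensor::detach` routes through shallow_copy_and_detach above:
//
//   at::Tensor crow = at::tensor({0, 1, 3}, at::kLong);
//   at::Tensor col = at::tensor({1, 0, 2}, at::kLong);
//   at::Tensor vals = at::tensor({1.0f, 2.0f, 3.0f});
//   at::Tensor a = at::sparse_csr_tensor(
//       crow, col, vals, {2, 3}, vals.options().layout(at::kSparseCsr));
//   at::Tensor b = a.detach(); // shallow copy: shares crow/col/values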