// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/SparseCsrTensorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Exception.h>
namespace at {

// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
// denoting the data: `crow_indices_`, `col_indices_` and `values_`.
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
// that represents the compressed row indices of the CSR tensor. The
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
// that explicitly stores the column indices of each value of the sparse
// tensor. The `values_` tensor can be of any pytorch-supported data type
// and has shape `(nnz())`.
//
// Since the main advantage of the CSR format over the COO format is speed of
// computation, care must be taken to facilitate smooth interfacing of
// these data structures with optimized libraries such as MKL and MAGMA.
// Since the MKL interface for pytorch currently uses indexing with int32
// type, it is important to make sure that the `crow_indices` and `col_indices`
// are of type int32 when calling MKL routines such as SPMM or SPMV.
//
// If not calling MKL, it should be alright to use 64 bit integer tensors
// for indexing.
27 struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
28   Tensor crow_indices_;
29   Tensor col_indices_;
30   Tensor values_;
31   Layout layout_;
32 
33  public:
34   explicit SparseCsrTensorImpl(
35       at::DispatchKeySet,
36       at::Device device,
37       Layout layout,
38       const caffe2::TypeMeta);
39 
40   void resize_(int64_t nnz, IntArrayRef size);
41   void resize_and_clear_(
42       int64_t sparse_dim,
43       int64_t dense_dim,
44       IntArrayRef size);
45   void resize_as_sparse_compressed_tensor_(const Tensor& src);
46   void set_member_tensors(
47       const Tensor& crow_indices,
48       const Tensor& col_indices,
49       const Tensor& values,
50       c10::SymIntArrayRef size);
51   void set_member_tensors(
52       const Tensor& crow_indices,
53       const Tensor& col_indices,
54       const Tensor& values,
55       IntArrayRef size);
compressed_indicesSparseCsrTensorImpl56   const Tensor& compressed_indices() const {
57     return crow_indices_;
58   }
plain_indicesSparseCsrTensorImpl59   const Tensor& plain_indices() const {
60     return col_indices_;
61   }
valuesSparseCsrTensorImpl62   const Tensor& values() const {
63     return values_;
64   }
nnzSparseCsrTensorImpl65   int64_t nnz() {
66     return col_indices_.size(-1);
67   }
68 
batch_dimSparseCsrTensorImpl69   inline int64_t batch_dim() const noexcept {
70     return crow_indices_.dim() - 1;
71   }
72 
sparse_dimSparseCsrTensorImpl73   inline int64_t sparse_dim() const noexcept {
74     return 2;
75   }
76 
dense_dimSparseCsrTensorImpl77   inline int64_t dense_dim() const noexcept {
78     return values_.dim() - batch_dim() - block_dim() - 1;
79   }
80 
81  private:
block_dimSparseCsrTensorImpl82   inline int64_t block_dim() const noexcept {
83     return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
84   }
85 
86  protected:
87   IntArrayRef strides_custom() const override;
88   SymIntArrayRef sym_strides_custom() const override;
89   bool is_contiguous_custom(MemoryFormat) const override;
90 
91  public:
92   void set_size(int64_t dim, int64_t new_size) override;
93   void set_stride(int64_t dim, int64_t new_stride) override;
94   void set_storage_offset(int64_t storage_offset) override;
layout_implSparseCsrTensorImpl95   Layout layout_impl() const override {
96     return layout_;
97   }
set_layoutSparseCsrTensorImpl98   void set_layout(Layout layout) {
99     switch (layout) {
100       case kSparseCsr:
101       case kSparseCsc:
102       case kSparseBsr:
103       case kSparseBsc:
104         layout_ = layout;
105         break;
106       default:
107         TORCH_CHECK(false, "unsupported layout ", layout);
108     }
109   }
110 
111   template <typename VariableVersion>
shallow_copy_and_detach_coreSparseCsrTensorImpl112   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
113       VariableVersion&& version_counter,
114       bool allow_tensor_metadata_change) const {
115     const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
116     c10::impl::PyInterpreter&& interpreter = nullptr;
117     if (mode_stack_len > 0 &&
118         !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
119       const auto& cur_torch_dispatch_mode_state =
120           c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
121       interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
122     } else if (
123         key_set_.has(DispatchKey::Python) &&
124         !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
125       interpreter = pyobj_slot_.load_pyobj_interpreter();
126     } else {
127       // otherwise just copy the SparseTensorImpl and not the PyObject.
128       auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
129           key_set(), device(), layout_impl(), dtype());
130       copy_tensor_metadata(
131           /*src_sparse_impl=*/this,
132           /*dest_sparse_impl=*/impl.get(),
133           /*version_counter=*/version_counter,
134           /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
135       impl->refresh_numel();
136       return impl;
137     }
138     auto r = interpreter->detach(this);
139     r->set_version_counter(std::forward<VariableVersion>(version_counter));
140     r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
141     return r;
142   }
143 
144   /**
145    * Return a TensorImpl that is a shallow-copy of this TensorImpl.
146    *
147    * For usage of `version_counter` and `allow_tensor_metadata_change`,
148    * see NOTE [ TensorImpl Shallow-Copying ].
149    */
shallow_copy_and_detachSparseCsrTensorImpl150   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
151       const c10::VariableVersion& version_counter,
152       bool allow_tensor_metadata_change) const override {
153     return shallow_copy_and_detach_core(
154         version_counter, allow_tensor_metadata_change);
155   }
156 
157   /**
158    * Return a TensorImpl that is a shallow-copy of this TensorImpl.
159    *
160    * For usage of `version_counter` and `allow_tensor_metadata_change`,
161    * see NOTE [ TensorImpl Shallow-Copying ].
162    */
shallow_copy_and_detachSparseCsrTensorImpl163   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
164       c10::VariableVersion&& version_counter,
165       bool allow_tensor_metadata_change) const override {
166     return shallow_copy_and_detach_core(
167         std::move(version_counter), allow_tensor_metadata_change);
168   }
169 
170  private:
171   explicit SparseCsrTensorImpl(
172       at::DispatchKeySet key_set,
173       const caffe2::TypeMeta data_type,
174       at::Tensor crow_indices,
175       at::Tensor col_indices,
176       at::Tensor values,
177       at::Layout layout);
178 
179   const char* tensorimpl_type_name() const override;
180 
181   /**
182    * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
183    * storage_offset) from one TensorImpl to another TensorImpl.
184    *
185    * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
186    * [ TensorImpl Shallow-Copying ].
187    */
copy_tensor_metadataSparseCsrTensorImpl188   static void copy_tensor_metadata(
189       const SparseCsrTensorImpl* src_sparse_impl,
190       SparseCsrTensorImpl* dest_sparse_impl,
191       c10::VariableVersion version_counter,
192       bool allow_tensor_metadata_change) {
193     TensorImpl::copy_tensor_metadata(
194         src_sparse_impl,
195         dest_sparse_impl,
196         std::move(version_counter),
197         allow_tensor_metadata_change);
198 
199     // Sparse-specific fields
200     dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
201     dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
202     dest_sparse_impl->values_ = src_sparse_impl->values();
203     dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
204   }
205 };
} // namespace at