// aten/src/ATen/NestedTensorImpl.h
#pragma once
#include <ATen/MemoryOverlap.h>
#include <ATen/Tensor.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/irange.h>

namespace at::native {
struct NestedTensorImpl;
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
at::Tensor construct_nested_strides(const at::Tensor& nested_size);
at::Tensor construct_offsets(const at::Tensor& nested_size);

struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  explicit NestedTensorImpl(
      Storage storage,
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);
  // Assumes a contiguous buffer: `nested_strides` and `storage_offsets`
  // can be inferred from `nested_sizes`.
  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      const at::Tensor& nested_sizes);
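  // Illustrative sketch (hypothetical values, not part of the API contract):
  // for a contiguous buffer of 18 floats packing two row-major tensors of
  // sizes (2, 3) and (4, 3), `nested_sizes` is the 2x2 matrix
  // [[2, 3], [4, 3]]; the inferred strides are [[3, 1], [3, 1]] and the
  // inferred storage offsets are [0, 6] (the second tensor starts right
  // after the first 2 * 3 elements).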

  // This constructor is used when creating view tensors from nested tensors.
  explicit NestedTensorImpl(
      c10::TensorImpl::ImplType impl_type,
      const at::Tensor& base_tensor,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
  const Tensor& get_nested_sizes() const {
    return nested_sizes_;
  }
  // TODO: don't expose private implementation details like this
  const Tensor& get_nested_strides() const {
    return nested_strides_;
  }
  const Tensor& get_storage_offsets() const {
    return storage_offsets_;
  }
  // Returns nullopt if the ith dimension is irregular. The ith dimension
  // of a NestedTensor is regular if the unbound tensors match in
  // size at the (i-1)th dimension.
  std::optional<int64_t> opt_size(int64_t d) const;

  int64_t size(int64_t d) const {
    std::optional<int64_t> optional_size = this->opt_size(d);
    TORCH_CHECK(
        optional_size.has_value(),
        "Given dimension ",
        d,
        " is irregular and does not have a size.");
    return *optional_size;
  }
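  // Illustrative sketch (assumed values): for a nested tensor built from two
  // dense tensors of sizes (2, 3) and (4, 3), dim() is 3, opt_size(0) is 2
  // (the number of nested tensors), opt_size(1) is std::nullopt (the lengths
  // 2 and 4 differ), and opt_size(2) is 3; size(1) therefore throws, while
  // size(2) returns 3.
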
  /**
   * Return a view of the nested tensor as a 1-dimensional contiguous tensor.
   *
   * The buffer tensor created by this function shares the same storage_impl as
   * the original nested tensor, and therefore can be seen as a view.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_buffer() const {
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(this),
        "NestedTensor must be contiguous to get buffer.");
    return get_unsafe_storage_as_tensor();
  }
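  // Illustrative sketch (hypothetical usage; assumes `nt` is a contiguous
  // nested tensor packing tensors of sizes (2, 3) and (4, 3) into a single
  // 18-element storage):
  //
  //   auto* impl = get_nested_tensor_impl(nt);
  //   at::Tensor buf = impl->get_buffer(); // 1-D view over the whole storage
  //   TORCH_INTERNAL_ASSERT(buf.numel() == 18);
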
  /**
   * If possible use get_buffer() instead. This function returns the storage
   * as a tensor directly, which is not safe to use in general. If using this
   * function, the caller must account for nested_sizes, nested_strides and
   * storage_offsets.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_unsafe_storage_as_tensor() const {
    auto buffer_key_set_ = generate_buffer_key_set();
    const auto buffer_size = get_buffer_size();
    auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
        c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
    buffer_tensor_impl->set_sizes_contiguous(
        c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
    return Tensor(buffer_tensor_impl);
  }

  size_t get_buffer_size() const {
    return storage_.nbytes() / data_type_.itemsize();
  }
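  // For example, a float32 buffer holding 18 elements occupies 72 bytes of
  // storage, so get_buffer_size() returns 72 / 4 == 18 (an element count,
  // not a byte count).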

 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  // with real implementations
  int64_t numel_custom() const override;
  c10::SymInt sym_numel_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;
  int64_t size_custom(int64_t d) const override {
    return this->size(d);
  }
  c10::SymInt sym_size_custom(int64_t d) const override {
    return c10::SymInt{this->size(d)};
  }
  IntArrayRef sizes_custom() const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  IntArrayRef strides_custom() const override;
  c10::SymIntArrayRef sym_strides_custom() const override;

  // this one is real
  int64_t dim_custom() const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    copy_tensor_metadata(
        /*src_impl=*/impl.get(),
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  }

 private:
  // Must be called after any change to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor nested_sizes_, nested_strides_;
  // The starting positions of the underlying tensors in the contiguous
  // buffer, i.e. the buffer memory offsets at which the underlying tensors
  // begin. We keep this metadata because, without a strong enough
  // constraint, it cannot be derived from `nested_sizes_` and
  // `nested_strides_`:
  // 1. when the buffer has gaps, e.g. [tensor1, blank, tensor2];
  //    this can happen e.g. after slicing a nested tensor
  // 2. when multiple tensors share the same memory
  // 3. when the nesting order is changed, e.g. [tensor1, tensor3, tensor2]
  // A strong enough constraint would be:
  // 1. every underlying tensor is contiguous in memory
  //    && the tensors are nested in ascending memory order
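  //
  // Illustrative sketch (hypothetical layout): packing tensors of sizes
  // (2, 3) and (4, 3) back-to-back gives storage_offsets_ == [0, 6];
  // slicing away the first tensor while keeping the original buffer could
  // leave storage_offsets_ == [6], an offset that cannot be reconstructed
  // from sizes and strides alone.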
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor storage_offsets_;
  // NOTE: -1 here means the size is missing
  // Optional to allow it to be computed lazily from `nested_sizes_`.
  // TODO: maybe we can remove this metadata since
  //       we can compute it from `nested_sizes_`
  mutable std::optional<std::vector<int64_t>> opt_sizes_;

  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  /**
   * Generates a non-nested key_set from a nested tensor.
   *
   * For many nested tensor kernel implementations, a buffer tensor is
   * generated and redispatched to a non-nested kernel; this function
   * generates the key set used by that buffer tensor.
   *
   * @return Appropriate key set for a non-nested tensor
   */
  inline c10::DispatchKeySet generate_buffer_key_set() const {
    auto buffer_key_set = this->key_set();
    const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
    // Remove nested tensor specific keys
    buffer_key_set = buffer_key_set -
        c10::DispatchKeySet{
            c10::DispatchKey::NestedTensor,
            c10::DispatchKey::AutogradNestedTensor};

    // Add dense tensor specific keys
    buffer_key_set =
        buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
    buffer_key_set = Autograd
        ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
        : buffer_key_set;

    return buffer_key_set;
  }
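
  // Illustrative sketch of the transformation above (rough, ignoring the
  // exact backend-bit encoding of DispatchKeySet): a nested CUDA tensor whose
  // key set contains NestedTensor, AutogradNestedTensor, and the CUDA backend
  // bit yields a buffer key set containing Dense, Autograd, and the same CUDA
  // backend bit, so the buffer dispatches like an ordinary dense tensor.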
};

inline NestedTensorImpl* get_nested_tensor_impl_or_null(
    const at::Tensor& tensor) {
  if (tensor.is_nested()) {
    return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
  }
  return nullptr;
}

inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
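
// Illustrative sketch (hypothetical kernel code; `nt` is assumed to be a
// nested tensor, e.g. one produced by at::_nested_tensor_from_tensor_list):
//
//   NestedTensorImpl* impl = get_nested_tensor_impl(nt); // throws if not nested
//   NestedTensorImpl* maybe = get_nested_tensor_impl_or_null(nt); // nullptr if not nested
//   const at::Tensor& sizes = impl->get_nested_sizes(); // typically (ntensors, orig_dim), int64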

inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
  int64_t ntensors = nt->size(0);
  if (ntensors == 0) {
    return true;
  }
  const Tensor &sizemat = nt->get_nested_sizes(),
               &stridemat = nt->get_nested_strides();
  const int64_t* offsets_ptr =
      nt->get_storage_offsets().const_data_ptr<int64_t>();
  int64_t orig_dim = sizemat.size(1);
  // nesting scalars
  if (orig_dim == 0) {
    // each scalar must be contiguous:
    // return false if there is blank memory between underlying scalars
    for (int64_t i = 0; i < ntensors; i++) {
      if (offsets_ptr[i] != i) {
        return false;
      }
    }
  }
  // nesting tensors
  else {
    // return false if any underlying tensor is non-contiguous
    const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
                  *stridemat_ptr = stridemat.const_data_ptr<int64_t>();
    for (int64_t i = 0; i < ntensors; i++) {
      if (stridemat_ptr[orig_dim - 1] != 1) {
        return false;
      }
      int64_t product = sizemat_ptr[orig_dim - 1];
      for (int64_t j = orig_dim - 2; j >= 0; j--) {
        if (stridemat_ptr[j] != product) {
          return false;
        }
        product *= sizemat_ptr[j];
      }
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
    // return false if there is blank memory between underlying tensors
    if (offsets_ptr[0] != 0) {
      return false;
    }
    sizemat_ptr = sizemat.const_data_ptr<int64_t>();
    stridemat_ptr = stridemat.const_data_ptr<int64_t>();
    for (int64_t i = 1; i < ntensors; i++) {
      if (offsets_ptr[i] !=
          offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
        return false;
      }
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
  }
  // everything is fine
  return true;
}
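
// Illustrative sketch (hypothetical metadata): for two row-major tensors of
// sizes (2, 3) and (4, 3) packed back-to-back,
//   sizemat   = [[2, 3], [4, 3]]
//   stridemat = [[3, 1], [3, 1]]
//   offsets   = [0, 6]
// every innermost stride is 1, each outer stride equals the running product
// of the sizes to its right, offsets[0] == 0, and offsets[1] == 0 + 2 * 3,
// so the check above returns true. With offsets = [0, 8] there would be a
// two-element gap and the function would return false.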

inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
  return get_nested_tensor_impl(tensor)->get_nested_sizes();
}

} // namespace at::native