1 #pragma once
2 #include <ATen/MemoryOverlap.h>
3 #include <ATen/Tensor.h>
4 #include <c10/core/DispatchKey.h>
5 #include <c10/core/DispatchKeySet.h>
6 #include <c10/core/MemoryFormat.h>
7 #include <c10/core/TensorImpl.h>
8 #include <c10/util/ArrayRef.h>
9 #include <c10/util/Exception.h>
10 #include <c10/util/Metaprogramming.h>
11 #include <c10/util/irange.h>
12
13 namespace at::native {
14 struct NestedTensorImpl;
15 inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
16 int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
17 at::Tensor construct_nested_strides(const at::Tensor& nested_size);
18 at::Tensor construct_offsets(const at::Tensor& nested_size);
19
20 struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
21 explicit NestedTensorImpl(
22 Storage storage,
23 c10::DispatchKeySet key_set,
24 const caffe2::TypeMeta data_type,
25 at::Tensor nested_sizes,
26 at::Tensor nested_strides,
27 at::Tensor storage_offsets);
28
29 explicit NestedTensorImpl(
30 const at::Tensor& buffer,
31 at::Tensor nested_sizes,
32 at::Tensor nested_strides,
33 at::Tensor storage_offsets);
34 // assume contiguous, `nested_strides` and `offsets`
35 // can be infered from `nested_sizes`
36 explicit NestedTensorImpl(
37 const at::Tensor& buffer,
38 const at::Tensor& nested_sizes);
39
40 // This constructor is used creating view tensors from nested tensors
41 explicit NestedTensorImpl(
42 c10::TensorImpl::ImplType impl_type,
43 const at::Tensor& base_tensor,
44 at::Tensor nested_sizes,
45 at::Tensor nested_strides,
46 at::Tensor storage_offsets);
47
48 // TODO: don't expose private implementation details like this; in
49 // particular, resizing this tensor will mess up our dim() and
50 // callers cannot fix it.
get_nested_sizesNestedTensorImpl51 const Tensor& get_nested_sizes() const {
52 return nested_sizes_;
53 }
54 // TODO: don't expose private implementation details like this
get_nested_stridesNestedTensorImpl55 const Tensor& get_nested_strides() const {
56 return nested_strides_;
57 }
get_storage_offsetsNestedTensorImpl58 const Tensor& get_storage_offsets() const {
59 return storage_offsets_;
60 }
61 // Returns nullopt if the ith dimension is irregular. The ith dimension
62 // of a NestedTensor is regular if the unbound tensors match in
63 // size at the (i-1)th dimension.
64 std::optional<int64_t> opt_size(int64_t d) const;
65
sizeNestedTensorImpl66 int64_t size(int64_t d) const {
67 std::optional<int64_t> optional_size = this->opt_size(d);
68 TORCH_CHECK(
69 optional_size.has_value(),
70 "Given dimension ",
71 d,
72 " is irregular and does not have a size.");
73 return *optional_size;
74 }
75 /**
76 * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
77 *
78 * The buffer tensor created by this function shares the same storage_impl as
79 * the original nested tensor, and therefore can be seen as a view.
80 *
81 * @return A newly constructed view tensor
82 */
get_bufferNestedTensorImpl83 at::Tensor get_buffer() const {
84 TORCH_CHECK(
85 nested_tensor_impl_is_contiguous(this),
86 "NestedTensor must be contiguous to get buffer.");
87 return get_unsafe_storage_as_tensor();
88 }
89 /**
90 * If possible use get_buffer() instead. This function returns the storage
91 * as a tensor directly, which is not safe to use in general. If using this
92 * function, The caller must ensure to account for nested_sizes,
93 * nested_strides and storage_offsets.
94 *
95 * @return A newly constructed view tensor
96 */
get_unsafe_storage_as_tensorNestedTensorImpl97 at::Tensor get_unsafe_storage_as_tensor() const {
98 auto buffer_key_set_ = generate_buffer_key_set();
99 const auto buffer_size = get_buffer_size();
100 auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
101 c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
102 buffer_tensor_impl->set_sizes_contiguous(
103 c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
104 return Tensor(buffer_tensor_impl);
105 }
106
get_buffer_sizeNestedTensorImpl107 size_t get_buffer_size() const {
108 return storage_.nbytes() / data_type_.itemsize();
109 }
110
111 protected:
112 const char* tensorimpl_type_name() const override;
113
114 // TODO: numel_custom and is_contiguous_custom can be profitably overridden
115 // with real implementations
116 int64_t numel_custom() const override;
117 c10::SymInt sym_numel_custom() const override;
118 bool is_contiguous_custom(MemoryFormat) const override;
size_customNestedTensorImpl119 int64_t size_custom(int64_t d) const override {
120 return this->size(d);
121 }
sym_size_customNestedTensorImpl122 c10::SymInt sym_size_custom(int64_t d) const override {
123 return c10::SymInt{this->size(d)};
124 }
125 IntArrayRef sizes_custom() const override;
126 c10::SymIntArrayRef sym_sizes_custom() const override;
127 IntArrayRef strides_custom() const override;
128 c10::SymIntArrayRef sym_strides_custom() const override;
129
130 // this one is real
131 int64_t dim_custom() const override;
132
133 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
134 const c10::VariableVersion& version_counter,
135 bool allow_tensor_metadata_change) const override;
136
137 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
138 c10::VariableVersion&& version_counter,
139 bool allow_tensor_metadata_change) const override;
140
shallow_copy_fromNestedTensorImpl141 void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
142 copy_tensor_metadata(
143 /*src_impl=*/impl.get(),
144 /*dest_impl=*/this,
145 /*version_counter=*/version_counter(),
146 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
147 }
148
149 private:
150 // Must be called after any changes to our dim() to sync the state
151 // to TensorImpl.
152 void refresh_dim();
153
154 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
155 const at::Tensor nested_sizes_, nested_strides_;
156 // The starting positions of the underlying tensors in contiguous buffer
157 // i.e. the buffer memory offsets to get the underlying tensors
158 // The reason to keep this metadata is that, without strong enough constraint
159 // it cannot be derived from `nested_sizes_`
160 // and `nested_strides_`:
161 // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
162 // this can happen e.g. after slicing a nested tensor
163 // 2. when multiple tensors share a same memory
164 // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
165 // Some strong enough constraints are:
166 // 1. every underlying tensor is contiguous in memory
167 // && nesting in ascending order
168 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
169 const at::Tensor storage_offsets_;
170 // NOTE: -1 here means the size is missing
171 // Optional to allow it to be computed lazily from nested.
172 // TODO: maybe we can remove this metadata since
173 // we can compute it from `nested_sizes_`
174 mutable std::optional<std::vector<int64_t>> opt_sizes_;
175
176 template <typename VariableVersion>
177 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
178 VariableVersion&& version_counter,
179 bool allow_tensor_metadata_change) const;
180
181 /**
182 * Generates a non-nested key_set from a nested tensor.
183 *
184 * For many nested tensor kernel implementations a buffer tensor
185 * is generated and redispatched to a non-nested kernel this function
186 * generates the key set used by that buffer tensor
187 *
188 * @return Appropriate key set for non-nested tensor
189 */
generate_buffer_key_setNestedTensorImpl190 inline c10::DispatchKeySet generate_buffer_key_set() const {
191 auto buffer_key_set = this->key_set();
192 const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
193 // Remove nested tensor specific keys
194 buffer_key_set = buffer_key_set -
195 c10::DispatchKeySet{
196 c10::DispatchKey::NestedTensor,
197 c10::DispatchKey::AutogradNestedTensor};
198
199 // Add dense tensor specific keys
200 buffer_key_set =
201 buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
202 buffer_key_set = Autograd
203 ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
204 : buffer_key_set;
205
206 return buffer_key_set;
207 }
208 };
209
get_nested_tensor_impl_or_null(const at::Tensor & tensor)210 inline NestedTensorImpl* get_nested_tensor_impl_or_null(
211 const at::Tensor& tensor) {
212 if (tensor.is_nested()) {
213 return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
214 }
215 return nullptr;
216 }
217
get_nested_tensor_impl(const at::Tensor & tensor)218 inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
219 TORCH_CHECK(
220 tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
221 return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
222 }
223
nested_tensor_impl_is_contiguous(const NestedTensorImpl * nt)224 inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
225 int64_t ntensors = nt->size(0);
226 if (ntensors == 0) {
227 return true;
228 }
229 const Tensor &sizemat = nt->get_nested_sizes(),
230 &stridemat = nt->get_nested_strides();
231 const int64_t* offsets_ptr =
232 nt->get_storage_offsets().const_data_ptr<int64_t>();
233 int64_t orig_dim = sizemat.size(1);
234 // nesting scalars
235 if (orig_dim == 0) {
236 // each scalar must be contiguous
237 // if there is blank memory between underlying scalars
238 for (int64_t i = 0; i < ntensors; i++) {
239 if (offsets_ptr[i] != i) {
240 return false;
241 }
242 }
243 }
244 // nesting tensors
245 else {
246 // if any underlying tensor is non-contiguous
247 const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
248 *stridemat_ptr = stridemat.const_data_ptr<int64_t>();
249 for (int64_t i = 0; i < ntensors; i++) {
250 if (stridemat_ptr[orig_dim - 1] != 1) {
251 return false;
252 }
253 int64_t product = sizemat_ptr[orig_dim - 1];
254 for (int64_t j = orig_dim - 2; j >= 0; j--) {
255 if (stridemat_ptr[j] != product) {
256 return false;
257 }
258 product *= sizemat_ptr[j];
259 }
260 sizemat_ptr += orig_dim;
261 stridemat_ptr += orig_dim;
262 }
263 // if there is blank memory between underlying tensors
264 if (offsets_ptr[0] != 0) {
265 return false;
266 }
267 sizemat_ptr = sizemat.const_data_ptr<int64_t>();
268 stridemat_ptr = stridemat.const_data_ptr<int64_t>();
269 for (int64_t i = 1; i < ntensors; i++) {
270 if (offsets_ptr[i] !=
271 offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
272 return false;
273 }
274 sizemat_ptr += orig_dim;
275 stridemat_ptr += orig_dim;
276 }
277 }
278 // everything is fine
279 return true;
280 }
281
// Convenience accessor: returns the nested-sizes tensor of `tensor`.
// Raises (through get_nested_tensor_impl) when `tensor` is not nested.
inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
  const NestedTensorImpl* const impl = get_nested_tensor_impl(tensor);
  return impl->get_nested_sizes();
}
285
286 } // namespace at::native
287