xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/NestedTensorUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/Dispatch.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/Parallel.h>
#include <ATen/core/Tensor.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS

#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones_native.h>
#include <ATen/ops/prod.h>
#include <ATen/ops/stack_native.h>
#include <ATen/ops/tensor.h>
#endif

#include <cstring>
#include <type_traits>
#include <utility>
#include <vector>
27 
28 namespace at::native {
29 struct NestedTensorImpl;
30 
31 // The following functions are used to construct nested tensors from buffers and
32 // metadata.
33 
wrap_buffer(at::Tensor buffer,at::Tensor nested_sizes)34 inline at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_sizes) {
35   TORCH_CHECK(
36       buffer.dim() == 1,
37       "Expected given buffer to be 1dim, but got ",
38       buffer.dim(),
39       " instead.");
40   TORCH_CHECK(
41       buffer.is_contiguous(), "Expected given buffer to be contiguous.");
42   return at::detail::make_tensor<NestedTensorImpl>(
43       std::move(buffer), std::move(nested_sizes));
44 }
45 
46 // TODO: Figure out if we need a non-moving wrap_buffer()
wrap_buffer(at::Tensor buffer,at::Tensor nested_sizes,at::Tensor nested_strides,at::Tensor storage_offsets)47 inline at::Tensor wrap_buffer(
48     at::Tensor buffer,
49     at::Tensor nested_sizes,
50     at::Tensor nested_strides,
51     at::Tensor storage_offsets) {
52   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
53       buffer.is_contiguous(), "Given buffer must be contiguous.");
54   return at::detail::make_tensor<NestedTensorImpl>(
55       std::move(buffer),
56       std::move(nested_sizes),
57       std::move(nested_strides),
58       std::move(storage_offsets));
59 }
60 
/// Returns the buffer held by the NestedTensorImpl behind `tensor`.
inline at::Tensor get_buffer(const at::Tensor& tensor) {
  auto* nt_impl = get_nested_tensor_impl(tensor);
  return nt_impl->get_buffer();
}
64 
65 /**
66  * Create a new nested tensor that is a view of a base nested tensor
67  *
68  * create_view_tensor calls a specialized constructor that copys the
69  * the keys from base onto the new view tensor being created.
70  * The storage is shared between the base and the returned view tensor
71  *
72  * All callers of this helper must:
73  * - Only return a view of the input
74  * - Must be explicit and define a derivative
75  *
76  * @param base Base tensor to construct view from.
77  * @param nested_sizes View tensors' sizes.
78  * @param nested_strides View tensors' strides.
79  * @param storage_offsets View tensors' offsets.
80  * @return A newly constructed view tensor
81  */
create_nested_view_tensor(const at::Tensor & base,at::Tensor nested_sizes,at::Tensor nested_strides,at::Tensor storage_offsets)82 inline at::Tensor create_nested_view_tensor(
83     const at::Tensor& base,
84     at::Tensor nested_sizes,
85     at::Tensor nested_strides,
86     at::Tensor storage_offsets) {
87   TORCH_INTERNAL_ASSERT(
88       base.is_nested(),
89       "This function can only be used to create nested tensor views");
90   TORCH_INTERNAL_ASSERT(
91       c10::impl::tls_local_dispatch_key_set().excluded_.has(
92           c10::DispatchKey::AutogradFunctionality),
93       "Creating a non differentiable nested tensor view in a CompositeImplicit function is not allowed.");
94   return at::detail::make_tensor<NestedTensorImpl>(
95       c10::TensorImpl::VIEW,
96       base,
97       nested_sizes,
98       nested_strides,
99       storage_offsets);
100 }
101 //  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102 
103 // Helper functions for getting information about a nested tensor's shape.
104 
105 int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
106 
107 // The sizes of the underlying tensors
NestedTensor_get_sizes(const NestedTensorImpl * self_ptr)108 inline std::vector<IntArrayRef> NestedTensor_get_sizes(
109     const NestedTensorImpl* self_ptr) {
110   int64_t ntensors = self_ptr->size(0);
111   std::vector<IntArrayRef> sizes(ntensors);
112   if (ntensors == 0) {
113     return sizes;
114   }
115   const Tensor& sizemat = self_ptr->get_nested_sizes();
116   int64_t orig_dim = sizemat.size(1);
117   // nesting scalars has empty sizes
118   if (orig_dim == 0) {
119     return sizes;
120   }
121   const int64_t* sizemat_ptr = sizemat.const_data_ptr<int64_t>();
122 
123   for (const auto i : c10::irange(ntensors)) {
124     sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
125     sizemat_ptr += orig_dim;
126   }
127   return sizes;
128 }
129 
130 TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
131     const NestedTensorImpl& nt);
132 
133 std::vector<int64_t> NestedTensor_get_max_size_from_size_tensor(
134     const Tensor& sizes);
135 
NestedTensor_get_sizes(const at::Tensor & self)136 inline std::vector<IntArrayRef> NestedTensor_get_sizes(const at::Tensor& self) {
137   const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
138   return NestedTensor_get_sizes(self_ptr);
139 }
140 // The strides of the underlying tensors
NestedTensor_get_strides(const NestedTensorImpl * self_ptr)141 inline std::vector<IntArrayRef> NestedTensor_get_strides(
142     const NestedTensorImpl* self_ptr) {
143   int64_t ntensors = self_ptr->size(0);
144   std::vector<IntArrayRef> strides(ntensors);
145   if (ntensors == 0) {
146     return strides;
147   }
148   const Tensor& stridemat = self_ptr->get_nested_strides();
149   int64_t orig_dim = stridemat.size(1);
150   // nesting scalars has empty strides
151   if (orig_dim == 0) {
152     return strides;
153   }
154   const int64_t* stridemat_ptr = stridemat.const_data_ptr<int64_t>();
155   for (const auto i : c10::irange(ntensors)) {
156     strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim);
157     stridemat_ptr += orig_dim;
158   }
159   return strides;
160 }
161 
NestedTensor_get_strides(const at::Tensor & self)162 inline std::vector<IntArrayRef> NestedTensor_get_strides(
163     const at::Tensor& self) {
164   const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
165   return NestedTensor_get_strides(self_ptr);
166 }
167 
check_numel_equals_buffer_size(const at::Tensor & self)168 inline void check_numel_equals_buffer_size(const at::Tensor& self) {
169   auto self_impl = get_nested_tensor_impl(self);
170   TORCH_CHECK(
171       self.numel() == static_cast<int64_t>(self_impl->get_buffer_size()),
172       "Number of elements in nested tensor must match number of elements in buffer.");
173 }
174 
check_numel_equals_buffer_size(const NestedTensorImpl * self_ptr)175 inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) {
176   TORCH_CHECK(
177       self_ptr->numel() == static_cast<int64_t>(self_ptr->get_buffer_size()),
178       "Number of elements in nested tensor must match number of elements in buffer.");
179 }
180 
// Helper functions to get the size / stride / storage offset of the i-th
// constituent of a nested tensor, or of the i-th slice along dim 0 of a
// regular tensor.
get_size_for_index(const Tensor & tensor,int64_t i)182 inline IntArrayRef get_size_for_index(const Tensor& tensor, int64_t i) {
183   if (tensor.is_nested()) {
184     std::vector<IntArrayRef> tensor_sizes =
185         NestedTensor_get_sizes(get_nested_tensor_impl(tensor));
186     return tensor_sizes[i];
187   } else {
188     return tensor.sizes().slice(1);
189   }
190 }
191 
get_stride_for_index(const Tensor & tensor,int64_t i)192 inline IntArrayRef get_stride_for_index(const Tensor& tensor, int64_t i) {
193   if (tensor.is_nested()) {
194     std::vector<IntArrayRef> tensor_strides =
195         NestedTensor_get_strides(get_nested_tensor_impl(tensor));
196     return tensor_strides[i];
197   } else {
198     return tensor.strides().slice(1);
199   }
200 }
201 
get_offset_for_index(const Tensor & tensor,int64_t i)202 inline int64_t get_offset_for_index(const Tensor& tensor, int64_t i) {
203   if (tensor.is_nested()) {
204     int64_t* offsets_ptr = get_nested_tensor_impl(tensor)
205                                ->get_storage_offsets()
206                                .data_ptr<int64_t>();
207     return offsets_ptr[i];
208 
209   } else {
210     int64_t offset = tensor.storage_offset();
211     return offset + tensor.strides()[0] * i;
212   }
213 }
214 //  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
215 // Data structures and functions for generically applying a function on a nested
216 // tensor.
217 namespace impl {
218 
219 template <typename T>
220 struct NestedNode {
221   NestedNode() = delete;
NestedNodeNestedNode222   explicit NestedNode(std::vector<T> children)
223       : _is_leaf(false), _children(std::move(children)) {}
NestedNodeNestedNode224   explicit NestedNode(TensorList children)
225       : _is_leaf(false), _children(children.vec()) {}
NestedNodeNestedNode226   explicit NestedNode(T payload)
227       : _is_leaf(true), _payload(std::move(payload)) {}
228   NestedNode(const NestedNode&) = delete;
229   NestedNode& operator=(const NestedNode&) = delete;
230   NestedNode(NestedNode&&) noexcept = default;
231   NestedNode& operator=(NestedNode&&) noexcept = default;
is_leafNestedNode232   inline bool is_leaf() const {
233     return _is_leaf;
234   }
degreeNestedNode235   inline size_t degree() const {
236     return _children.size();
237   }
unbindNestedNode238   inline const std::vector<T> unbind() const {
239     return _children;
240   }
childrenNestedNode241   inline T children(size_t i) const {
242     return _children[i];
243   }
payloadNestedNode244   inline const T& payload() const {
245     return _payload;
246   }
payloadNestedNode247   inline T& payload() {
248     return _payload;
249   }
250 
251  private:
252   bool _is_leaf;
253   std::vector<T> _children;
254   T _payload{};
255 };
256 
257 using TensorNode = NestedNode<at::Tensor>;
258 
259 template <class F, class A, class TypeList>
260 class _map;
261 
262 template <class F, class A, class... Args>
263 class _map<F, A, c10::guts::typelist::typelist<Args...>> {
264  public:
function_one(F && fn,const Args &...nested_node)265   static A function_one(F&& fn, const Args&... nested_node) {
266     return std::forward<F>(fn)(nested_node...);
267   }
function(const F & fn,const NestedNode<Args> &...nested_node)268   static NestedNode<A> function(
269       const F& fn,
270       const NestedNode<Args>&... nested_node) {
271     size_t degree = 0;
272     bool all_leaf = true;
273     c10::guts::tuple_map(
274         std::forward_as_tuple(nested_node...), [&all_leaf, &degree](auto n) {
275           all_leaf = all_leaf && (n.is_leaf());
276           if (degree > 1 && n.degree() > 1) {
277             TORCH_CHECK(
278                 degree == n.degree(), "NestedNodes must match in degree.");
279           }
280           if (n.degree() > degree) {
281             degree = n.degree();
282           }
283           return nullptr;
284         });
285     // All NestedNodes just wrap regular objects.
286     if (all_leaf) {
287       return NestedNode<A>(std::forward<F>(fn)(nested_node.payload()...));
288     }
289     // Some NestedNodes wrap regular Tensors, some NestedTensors and some other
290     // types.
291     std::vector<A> result;
292     for (size_t i = 0; i < degree; i++) {
293       auto children = c10::guts::tuple_map(
294           std::forward_as_tuple(nested_node...), [&i](auto a) {
295             static_assert(
296                 c10::guts::is_instantiation_of<NestedNode, decltype(a)>::value,
297                 "Internal error.");
298             // Broadcast regular arguments across NestedTensor constituents.
299             // This could be a Tensor, integer or anything else really.
300             if (a.is_leaf()) {
301               return a.payload();
302             }
303             // Broadcast NestedTensors with one constituent.
304             if (a.degree() == 1 && !a.is_leaf()) {
305               return a.children(0);
306             }
307             TORCH_CHECK(a.degree() > 0, "Internal assert.");
308             return a.children(i);
309           });
310       c10::guts::apply(
311           [&result, &fn](Args... filtered) {
312             result.emplace_back(function_one(fn, filtered...));
313           },
314           std::move(children));
315     }
316     return NestedNode<A>(std::move(result));
317   }
318 };
319 
320 // TODO: Add static assert to verify lambda arguments match nested_node types
321 template <class F, class... B>
322 static inline NestedNode<
323     typename c10::guts::infer_function_traits<F>::type::return_type>
map(F && fn,const NestedNode<B> &...nested_node)324 map(F&& fn, const NestedNode<B>&... nested_node) {
325   return _map<
326       F,
327       typename c10::guts::infer_function_traits<F>::type::return_type,
328       typename c10::guts::infer_function_traits<F>::type::parameter_types>::
329       function(std::forward<F>(fn), nested_node...);
330 }
331 
get_nested_tensor_structure(at::Tensor tensor)332 inline TensorNode get_nested_tensor_structure(at::Tensor tensor) {
333   if (get_nested_tensor_impl_or_null(tensor) == nullptr) {
334     return TensorNode(std::move(tensor));
335   }
336   return TensorNode(tensor.unbind());
337 }
338 
wrap_tensor_node(TensorNode tensor_node,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory)339 inline Tensor wrap_tensor_node(
340     TensorNode tensor_node,
341     std::optional<ScalarType> dtype,
342     std::optional<Layout> layout,
343     std::optional<Device> device,
344     std::optional<bool> pin_memory) {
345   TORCH_CHECK(
346       !tensor_node.is_leaf(), "Expected TensorNode to wrap a list of Tensors.");
347   TensorOptions options_ =
348       TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
349           pin_memory);
350   if (tensor_node.degree() == 0) {
351     return wrap_buffer(ones({0}, dtype, layout, device), ones({}));
352   }
353 
354   // Fast path: if all tensors are on CPU, have contiguous memory, and the same
355   // dtype, copying can be done much faster.
356   bool all_tensors_cpu = true;
357   bool all_tensors_contiguous = true;
358   bool all_tensors_same_dtype = true;
359   auto first_dtype = tensor_node.children(0).dtype();
360   std::vector<long> start_offsets(tensor_node.degree());
361   start_offsets[0] = 0;
362   long total_size = 0;
363   for (const auto i : c10::irange(tensor_node.degree())) {
364     all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu();
365     all_tensors_contiguous =
366         all_tensors_contiguous && tensor_node.children(i).is_contiguous();
367     all_tensors_same_dtype = all_tensors_same_dtype &&
368         (first_dtype == tensor_node.children(i).dtype());
369     if (!(all_tensors_cpu && all_tensors_contiguous &&
370           all_tensors_same_dtype)) {
371       break;
372     }
373     if (i > 0) {
374       start_offsets[i] =
375           start_offsets[i - 1] + tensor_node.children(i - 1).numel();
376     }
377     total_size += tensor_node.children(i).numel();
378   }
379 
380   TensorOptions options;
381   Tensor nt_buffer, nt_sizes;
382   if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) {
383     nt_buffer = at::empty({total_size}, tensor_node.children(0).options());
384     nt_sizes = at::empty(
385         {static_cast<long>(tensor_node.degree()),
386          static_cast<long>(tensor_node.children(0).sizes().size())},
387         TensorOptions().dtype(kLong));
388     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
389         at::ScalarType::Half,
390         at::ScalarType::Bool,
391         at::ScalarType::BFloat16,
392         c10::typeMetaToScalarType(first_dtype),
393         "create_nt_buffer",
394         [&]() {
395           at::parallel_for(
396               0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) {
397                 for (int64_t i = begin; i < end; ++i) {
398                   // Only try copying memory if there is more than 0 elements
399                   // for a certain tensor
400                   if (tensor_node.children(i).numel() > 0) {
401                     memcpy(
402                         nt_buffer.mutable_data_ptr<scalar_t>() + start_offsets[i],
403                         tensor_node.children(i).const_data_ptr<scalar_t>(),
404                         tensor_node.children(i).numel() * sizeof(scalar_t));
405                   }
406                 }
407               });
408         });
409     long sizes_offset = 0;
410     for (size_t i = 0; i < tensor_node.degree(); ++i) {
411       auto tensor_sizes = tensor_node.children(i).sizes();
412       for (int64_t tensor_size : tensor_sizes) {
413         nt_sizes.mutable_data_ptr<int64_t>()[sizes_offset++] = tensor_size;
414       }
415     }
416     options = nt_buffer.options().merge_in(options_);
417   } else { // Slow path
418     std::vector<Tensor> flat_tensors;
419     std::vector<Tensor> sizes;
420     for (const auto i : c10::irange(tensor_node.degree())) {
421       flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous());
422       sizes.push_back(
423           tensor(c10::IntArrayRef(tensor_node.children(i).sizes())));
424     }
425     options = flat_tensors[0].options().merge_in(options_);
426     nt_buffer = at::cat(flat_tensors);
427     nt_sizes = at::native::stack(sizes);
428   }
429 
430   return wrap_buffer(nt_buffer.to(options), nt_sizes);
431 }
432 
433 } // namespace impl
434 
435 // This function is meant to ease rapid operator coverage for
436 // NestedTensor kernels. It is not meant to be efficient. Use it judiciously.
437 template <class F, class... A>
map_nested_tensor(F && fn,A...a)438 inline at::Tensor map_nested_tensor(F&& fn, A... a) {
439   return wrap_tensor_node(
440       impl::map(std::forward<F>(fn), impl::get_nested_tensor_structure(a)...),
441       std::nullopt,
442       std::nullopt,
443       std::nullopt,
444       std::nullopt);
445 }
446 
447 } // namespace at::native
448