#pragma once

#include <ATen/Dispatch.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/Parallel.h>
#include <ATen/core/Tensor.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS

#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones_native.h>
#include <ATen/ops/prod.h>
#include <ATen/ops/stack_native.h>
#include <ATen/ops/tensor.h>
#endif

#include <utility>
#include <vector>

namespace at::native {
struct NestedTensorImpl;

// The following functions are used to construct nested tensors from buffers and
// metadata.

inline at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_sizes) {
  TORCH_CHECK(
      buffer.dim() == 1,
      "Expected given buffer to be 1dim, but got ",
      buffer.dim(),
      " instead.");
  TORCH_CHECK(
      buffer.is_contiguous(), "Expected given buffer to be contiguous.");
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(buffer), std::move(nested_sizes));
}

// TODO: Figure out if we need a non-moving wrap_buffer()
inline at::Tensor wrap_buffer(
    at::Tensor buffer,
    at::Tensor nested_sizes,
    at::Tensor nested_strides,
    at::Tensor storage_offsets) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      buffer.is_contiguous(), "Given buffer must be contiguous.");
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(buffer),
      std::move(nested_sizes),
      std::move(nested_strides),
      std::move(storage_offsets));
}

inline at::Tensor get_buffer(const at::Tensor& tensor) {
  return get_nested_tensor_impl(tensor)->get_buffer();
}
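
// Illustrative usage (a sketch only, not compiled here): wrap a flat,
// contiguous buffer of 6 floats as a nested tensor with two 1-D components of
// lengths 2 and 4, then read the flat buffer back out. The concrete values
// below are arbitrary examples.
//
//   at::Tensor buffer = at::arange(6, at::kFloat);
//   // One row of the size matrix per component: sizes (2,) and (4,).
//   at::Tensor sizes =
//       at::tensor(std::vector<int64_t>{2, 4}, at::kLong).reshape({2, 1});
//   at::Tensor nt = wrap_buffer(buffer, sizes);
//   at::Tensor flat = get_buffer(nt); // shares storage with `buffer`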

/**
 * Create a new nested tensor that is a view of a base nested tensor.
 *
 * create_nested_view_tensor calls a specialized constructor that copies the
 * keys from base onto the new view tensor being created.
 * The storage is shared between the base and the returned view tensor.
 *
 * All callers of this helper must:
 * - Only return a view of the input
 * - Be explicit and define a derivative
 *
 * @param base Base tensor to construct view from.
 * @param nested_sizes View tensors' sizes.
 * @param nested_strides View tensors' strides.
 * @param storage_offsets View tensors' offsets.
 * @return A newly constructed view tensor
 */
inline at::Tensor create_nested_view_tensor(
    const at::Tensor& base,
    at::Tensor nested_sizes,
    at::Tensor nested_strides,
    at::Tensor storage_offsets) {
  TORCH_INTERNAL_ASSERT(
      base.is_nested(),
      "This function can only be used to create nested tensor views");
  TORCH_INTERNAL_ASSERT(
      c10::impl::tls_local_dispatch_key_set().excluded_.has(
          c10::DispatchKey::AutogradFunctionality),
      "Creating a non differentiable nested tensor view in a CompositeImplicit function is not allowed.");
  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      base,
      nested_sizes,
      nested_strides,
      storage_offsets);
}
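
// Illustrative sketch of how a view op might call this helper. The wrapper
// name `my_nested_view_op` and the cloned metadata below are hypothetical;
// real callers live in the nested-tensor view kernels, run below the autograd
// dispatch keys, and must be registered with an explicit derivative.
//
//   Tensor my_nested_view_op(const Tensor& self) {
//     auto* self_impl = get_nested_tensor_impl(self);
//     return create_nested_view_tensor(
//         self,
//         self_impl->get_nested_sizes().clone(),
//         self_impl->get_nested_strides().clone(),
//         self_impl->get_storage_offsets().clone());
//   }
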
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Helper functions for getting information about a nested tensor's shape.

int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);

// The sizes of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_sizes(
    const NestedTensorImpl* self_ptr) {
  int64_t ntensors = self_ptr->size(0);
  std::vector<IntArrayRef> sizes(ntensors);
  if (ntensors == 0) {
    return sizes;
  }
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  int64_t orig_dim = sizemat.size(1);
  // nested scalars have empty sizes
  if (orig_dim == 0) {
    return sizes;
  }
  const int64_t* sizemat_ptr = sizemat.const_data_ptr<int64_t>();

  for (const auto i : c10::irange(ntensors)) {
    sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
    sizemat_ptr += orig_dim;
  }
  return sizes;
}

TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
    const NestedTensorImpl& nt);

std::vector<int64_t> NestedTensor_get_max_size_from_size_tensor(
    const Tensor& sizes);

inline std::vector<IntArrayRef> NestedTensor_get_sizes(const at::Tensor& self) {
  const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
  return NestedTensor_get_sizes(self_ptr);
}
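
// Illustrative example: for a nested tensor holding components of shapes
// (2, 3) and (4, 3), the nested size matrix is [[2, 3], [4, 3]] and
// NestedTensor_get_sizes returns {IntArrayRef{2, 3}, IntArrayRef{4, 3}}, one
// entry per component. The returned IntArrayRefs alias the size matrix, so
// they are only valid for the lifetime of the nested tensor.
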
// The strides of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_strides(
    const NestedTensorImpl* self_ptr) {
  int64_t ntensors = self_ptr->size(0);
  std::vector<IntArrayRef> strides(ntensors);
  if (ntensors == 0) {
    return strides;
  }
  const Tensor& stridemat = self_ptr->get_nested_strides();
  int64_t orig_dim = stridemat.size(1);
  // nested scalars have empty strides
  if (orig_dim == 0) {
    return strides;
  }
  const int64_t* stridemat_ptr = stridemat.const_data_ptr<int64_t>();
  for (const auto i : c10::irange(ntensors)) {
    strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim);
    stridemat_ptr += orig_dim;
  }
  return strides;
}

inline std::vector<IntArrayRef> NestedTensor_get_strides(
    const at::Tensor& self) {
  const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
  return NestedTensor_get_strides(self_ptr);
}

inline void check_numel_equals_buffer_size(const at::Tensor& self) {
  auto self_impl = get_nested_tensor_impl(self);
  TORCH_CHECK(
      self.numel() == static_cast<int64_t>(self_impl->get_buffer_size()),
      "Number of elements in nested tensor must match number of elements in buffer.");
}

inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) {
  TORCH_CHECK(
      self_ptr->numel() == static_cast<int64_t>(self_ptr->get_buffer_size()),
      "Number of elements in nested tensor must match number of elements in buffer.");
}

// Helper functions to get the size / stride / offset of component i, for
// either a nested tensor or a regular (dense) tensor indexed along dim 0.
inline IntArrayRef get_size_for_index(const Tensor& tensor, int64_t i) {
  if (tensor.is_nested()) {
    std::vector<IntArrayRef> tensor_sizes =
        NestedTensor_get_sizes(get_nested_tensor_impl(tensor));
    return tensor_sizes[i];
  } else {
    return tensor.sizes().slice(1);
  }
}

inline IntArrayRef get_stride_for_index(const Tensor& tensor, int64_t i) {
  if (tensor.is_nested()) {
    std::vector<IntArrayRef> tensor_strides =
        NestedTensor_get_strides(get_nested_tensor_impl(tensor));
    return tensor_strides[i];
  } else {
    return tensor.strides().slice(1);
  }
}

inline int64_t get_offset_for_index(const Tensor& tensor, int64_t i) {
  if (tensor.is_nested()) {
    int64_t* offsets_ptr = get_nested_tensor_impl(tensor)
                               ->get_storage_offsets()
                               .data_ptr<int64_t>();
    return offsets_ptr[i];
  } else {
    int64_t offset = tensor.storage_offset();
    return offset + tensor.strides()[0] * i;
  }
}
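
// Illustrative sketch of how these helpers are typically used together: walk
// the components of `t`, which may be either a nested tensor or a dense tensor
// whose first dimension indexes components (the loop body is a placeholder,
// not an API of this header).
//
//   for (int64_t i = 0; i < t.size(0); ++i) {
//     IntArrayRef component_sizes = get_size_for_index(t, i);
//     IntArrayRef component_strides = get_stride_for_index(t, i);
//     int64_t component_offset = get_offset_for_index(t, i);
//     // ... materialize or inspect component i from
//     //     (component_sizes, component_strides, component_offset)
//   }
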
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Data structures and functions for generically applying a function on a nested
// tensor.
namespace impl {

template <typename T>
struct NestedNode {
  NestedNode() = delete;
  explicit NestedNode(std::vector<T> children)
      : _is_leaf(false), _children(std::move(children)) {}
  explicit NestedNode(TensorList children)
      : _is_leaf(false), _children(children.vec()) {}
  explicit NestedNode(T payload)
      : _is_leaf(true), _payload(std::move(payload)) {}
  NestedNode(const NestedNode&) = delete;
  NestedNode& operator=(const NestedNode&) = delete;
  NestedNode(NestedNode&&) noexcept = default;
  NestedNode& operator=(NestedNode&&) noexcept = default;
  inline bool is_leaf() const {
    return _is_leaf;
  }
  inline size_t degree() const {
    return _children.size();
  }
  inline const std::vector<T> unbind() const {
    return _children;
  }
  inline T children(size_t i) const {
    return _children[i];
  }
  inline const T& payload() const {
    return _payload;
  }
  inline T& payload() {
    return _payload;
  }

 private:
  bool _is_leaf;
  std::vector<T> _children;
  T _payload{};
};

using TensorNode = NestedNode<at::Tensor>;
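
// Illustrative sketch of constructing TensorNodes directly: a leaf node wraps
// a single payload, while a non-leaf node wraps a list of children.
//
//   TensorNode leaf(at::ones({2, 3}));              // is_leaf() == true
//   TensorNode list(std::vector<at::Tensor>{
//       at::ones({2, 3}), at::ones({4, 3})});       // degree() == 2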

template <class F, class A, class TypeList>
class _map;

template <class F, class A, class... Args>
class _map<F, A, c10::guts::typelist::typelist<Args...>> {
 public:
  static A function_one(F&& fn, const Args&... nested_node) {
    return std::forward<F>(fn)(nested_node...);
  }
  static NestedNode<A> function(
      const F& fn,
      const NestedNode<Args>&... nested_node) {
    size_t degree = 0;
    bool all_leaf = true;
    c10::guts::tuple_map(
        std::forward_as_tuple(nested_node...), [&all_leaf, &degree](auto n) {
          all_leaf = all_leaf && (n.is_leaf());
          if (degree > 1 && n.degree() > 1) {
            TORCH_CHECK(
                degree == n.degree(), "NestedNodes must match in degree.");
          }
          if (n.degree() > degree) {
            degree = n.degree();
          }
          return nullptr;
        });
    // All NestedNodes just wrap regular objects.
    if (all_leaf) {
      return NestedNode<A>(std::forward<F>(fn)(nested_node.payload()...));
    }
    // Some NestedNodes wrap regular Tensors, some NestedTensors and some other
    // types.
    std::vector<A> result;
    for (size_t i = 0; i < degree; i++) {
      auto children = c10::guts::tuple_map(
          std::forward_as_tuple(nested_node...), [&i](auto a) {
            static_assert(
                c10::guts::is_instantiation_of<NestedNode, decltype(a)>::value,
                "Internal error.");
            // Broadcast regular arguments across NestedTensor constituents.
            // This could be a Tensor, integer or anything else really.
            if (a.is_leaf()) {
              return a.payload();
            }
            // Broadcast NestedTensors with one constituent.
            if (a.degree() == 1 && !a.is_leaf()) {
              return a.children(0);
            }
            TORCH_CHECK(a.degree() > 0, "Internal assert.");
            return a.children(i);
          });
      c10::guts::apply(
          [&result, &fn](Args... filtered) {
            result.emplace_back(function_one(fn, filtered...));
          },
          std::move(children));
    }
    return NestedNode<A>(std::move(result));
  }
};

// TODO: Add static assert to verify lambda arguments match nested_node types
template <class F, class... B>
static inline NestedNode<
    typename c10::guts::infer_function_traits<F>::type::return_type>
map(F&& fn, const NestedNode<B>&... nested_node) {
  return _map<
      F,
      typename c10::guts::infer_function_traits<F>::type::return_type,
      typename c10::guts::infer_function_traits<F>::type::parameter_types>::
      function(std::forward<F>(fn), nested_node...);
}
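
// Illustrative sketch of the generic combinator: zip two TensorNodes with a
// binary function. As described in _map::function above, leaf nodes and
// single-constituent nodes are broadcast across the other arguments.
//
//   TensorNode a(std::vector<at::Tensor>{at::ones({2}), at::ones({3})});
//   TensorNode b(at::full({}, 2.0)); // leaf: broadcast across a's children
//   TensorNode scaled =
//       map([](at::Tensor x, at::Tensor y) { return x * y; }, a, b);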

inline TensorNode get_nested_tensor_structure(at::Tensor tensor) {
  if (get_nested_tensor_impl_or_null(tensor) == nullptr) {
    return TensorNode(std::move(tensor));
  }
  return TensorNode(tensor.unbind());
}

inline Tensor wrap_tensor_node(
    TensorNode tensor_node,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  TORCH_CHECK(
      !tensor_node.is_leaf(), "Expected TensorNode to wrap a list of Tensors.");
  TensorOptions options_ =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
          pin_memory);
  if (tensor_node.degree() == 0) {
    return wrap_buffer(ones({0}, dtype, layout, device), ones({}));
  }

  // Fast path: if all tensors are on CPU, have contiguous memory, and the same
  // dtype, copying can be done much faster.
  bool all_tensors_cpu = true;
  bool all_tensors_contiguous = true;
  bool all_tensors_same_dtype = true;
  auto first_dtype = tensor_node.children(0).dtype();
  std::vector<long> start_offsets(tensor_node.degree());
  start_offsets[0] = 0;
  long total_size = 0;
  for (const auto i : c10::irange(tensor_node.degree())) {
    all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu();
    all_tensors_contiguous =
        all_tensors_contiguous && tensor_node.children(i).is_contiguous();
    all_tensors_same_dtype = all_tensors_same_dtype &&
        (first_dtype == tensor_node.children(i).dtype());
    if (!(all_tensors_cpu && all_tensors_contiguous &&
          all_tensors_same_dtype)) {
      break;
    }
    if (i > 0) {
      start_offsets[i] =
          start_offsets[i - 1] + tensor_node.children(i - 1).numel();
    }
    total_size += tensor_node.children(i).numel();
  }

  TensorOptions options;
  Tensor nt_buffer, nt_sizes;
  if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) {
    nt_buffer = at::empty({total_size}, tensor_node.children(0).options());
    nt_sizes = at::empty(
        {static_cast<long>(tensor_node.degree()),
         static_cast<long>(tensor_node.children(0).sizes().size())},
        TensorOptions().dtype(kLong));
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        at::ScalarType::Half,
        at::ScalarType::Bool,
        at::ScalarType::BFloat16,
        c10::typeMetaToScalarType(first_dtype),
        "create_nt_buffer",
        [&]() {
          at::parallel_for(
              0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) {
                for (int64_t i = begin; i < end; ++i) {
                  // Only copy memory if this tensor has at least one element
                  if (tensor_node.children(i).numel() > 0) {
                    memcpy(
                        nt_buffer.mutable_data_ptr<scalar_t>() + start_offsets[i],
                        tensor_node.children(i).const_data_ptr<scalar_t>(),
                        tensor_node.children(i).numel() * sizeof(scalar_t));
                  }
                }
              });
        });
    long sizes_offset = 0;
    for (size_t i = 0; i < tensor_node.degree(); ++i) {
      auto tensor_sizes = tensor_node.children(i).sizes();
      for (int64_t tensor_size : tensor_sizes) {
        nt_sizes.mutable_data_ptr<int64_t>()[sizes_offset++] = tensor_size;
      }
    }
    options = nt_buffer.options().merge_in(options_);
  } else { // Slow path
    std::vector<Tensor> flat_tensors;
    std::vector<Tensor> sizes;
    for (const auto i : c10::irange(tensor_node.degree())) {
      flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous());
      sizes.push_back(
          tensor(c10::IntArrayRef(tensor_node.children(i).sizes())));
    }
    options = flat_tensors[0].options().merge_in(options_);
    nt_buffer = at::cat(flat_tensors);
    nt_sizes = at::native::stack(sizes);
  }

  return wrap_buffer(nt_buffer.to(options), nt_sizes);
}
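
// Illustrative sketch of wrapping a TensorNode into a contiguous nested
// tensor, letting dtype / layout / device / pin_memory default to the
// components' options:
//
//   TensorNode node(std::vector<at::Tensor>{
//       at::ones({2, 3}), at::ones({4, 3})});
//   Tensor nt = wrap_tensor_node(
//       std::move(node), std::nullopt, std::nullopt, std::nullopt,
//       std::nullopt);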

} // namespace impl

// This function is meant to ease rapid operator coverage for
// NestedTensor kernels. It is not meant to be efficient. Use it judiciously.
template <class F, class... A>
inline at::Tensor map_nested_tensor(F&& fn, A... a) {
  return wrap_tensor_node(
      impl::map(std::forward<F>(fn), impl::get_nested_tensor_structure(a)...),
      std::nullopt,
      std::nullopt,
      std::nullopt,
      std::nullopt);
}
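
// Illustrative usage (a sketch; `nt` is assumed to be an existing nested
// tensor): apply an elementwise function to every constituent and collect the
// results into a new contiguous nested tensor.
//
//   at::Tensor doubled = map_nested_tensor(
//       [](at::Tensor t) { return t * 2; }, nt);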

} // namespace at::native