// xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_flatten.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/core/functional.h>
5 #include <c10/core/TensorOptions.h>
6 #include <torch/csrc/Export.h>
7 #include <utility>
8 
9 namespace torch::utils {
10 
11 /// Generate an ID for a combination of tensor backend + scalar type to be used
12 /// when ordering tensors ('like' tensors are grouped by pulling out their
13 /// backend + scalar type, so this function combines that into a single number)
type_id(const at::Tensor & tensor)14 inline size_t type_id(const at::Tensor& tensor) {
15   return static_cast<size_t>(tensor.options().backend()) *
16       static_cast<size_t>(at::ScalarType::NumOptions) +
17       static_cast<size_t>(tensor.scalar_type());
18 }
19 
// Flatten a list of dense tensors into one contiguous 1-D tensor.
// Thin delegation to at::flatten_dense_tensors; see its documentation for the
// exact concatenation/layout guarantees.
inline at::Tensor flatten_dense_tensors(at::TensorList tensors) {
  return at::flatten_dense_tensors(tensors);
}
23 
// Split `flat` back into tensors shaped like those in `tensors` (the list
// supplies only the sizes/types used to carve up `flat`). Thin delegation to
// at::unflatten_dense_tensors — presumably the results alias `flat`'s
// storage; confirm against that op's documentation before relying on it.
inline std::vector<at::Tensor> unflatten_dense_tensors(
    const at::Tensor& flat,
    at::TensorList tensors) {
  return at::unflatten_dense_tensors(flat, tensors);
}
29 
30 struct TensorGroup {
31   std::vector<at::Tensor> tensors;
32   size_t size = 0;
33 
type_idTensorGroup34   size_t type_id() {
35     AT_ASSERT(!tensors.empty());
36     return ::torch::utils::type_id(tensors[0]);
37   }
38 
optionsTensorGroup39   const at::TensorOptions options() {
40     AT_ASSERT(!tensors.empty());
41     return tensors[0].options();
42   }
43 };
44 
// Helper function that takes a list of tensors and splits them into tensor
// groups by the size limit and outputs these tensor groups. If the input
// tensors are of different tensor types, they will be split into different
// groups as well.
//
// Two options of splitting provided to the user,
//
// Imagine the size_limit is 256 and the list of input tensors are:
// tensor_a(fp16 - 128 bytes),
// tensor_b(fp32 - 256 bytes),
// tensor_c(fp16 - 128 bytes),
//
// when fine_grained == false:
// The function will read the list of tensors sequentially and accumulate
// enough tensors for each data type until the size_limit, therefore:
// it will output: {{tensor_a, tensor_c}, {tensor_b}}
//
// when fine_grained == true:
// The function will read the list of tensors sequentially and accumulate
// enough tensors for all data types until the size_limit, and then split
// the accumulated tensors into different groups by data types, therefore:
// it will output: {{tensor_a}, {tensor_b}, {tensor_c}}
TORCH_API std::vector<TensorGroup> take_tensors(
    at::TensorList tensors,
    size_t size_limit,
    bool fine_grained = false);
71 
// Reorders `tensors` in place to follow the ordering of `order` — presumably
// matching by type/grouping; confirm against the implementation in
// tensor_flatten.cpp.
TORCH_API void reorder_tensors_like(
    std::vector<at::Tensor>& tensors,
    at::TensorList order);

// Flattens a list of sparse tensors into a single pair of tensors —
// (flat_indices, flat_values), inferred from unflatten_sparse_tensors'
// parameter names below; verify in the implementation.
TORCH_API std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors);

// Inverse of flatten_sparse_tensors: rebuilds one sparse tensor per entry of
// `tensors` from the flattened indices/values; `tensors` supplies the
// per-tensor metadata used for the split.
TORCH_API std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors);

} // namespace torch::utils