1 #pragma once
2
3 #include <ATen/ATen.h>
4 #include <ATen/core/functional.h>
5 #include <c10/core/TensorOptions.h>
6 #include <torch/csrc/Export.h>
7 #include <utility>
8
9 namespace torch::utils {
10
11 /// Generate an ID for a combination of tensor backend + scalar type to be used
12 /// when ordering tensors ('like' tensors are grouped by pulling out their
13 /// backend + scalar type, so this function combines that into a single number)
type_id(const at::Tensor & tensor)14 inline size_t type_id(const at::Tensor& tensor) {
15 return static_cast<size_t>(tensor.options().backend()) *
16 static_cast<size_t>(at::ScalarType::NumOptions) +
17 static_cast<size_t>(tensor.scalar_type());
18 }
19
flatten_dense_tensors(at::TensorList tensors)20 inline at::Tensor flatten_dense_tensors(at::TensorList tensors) {
21 return at::flatten_dense_tensors(tensors);
22 }
23
unflatten_dense_tensors(const at::Tensor & flat,at::TensorList tensors)24 inline std::vector<at::Tensor> unflatten_dense_tensors(
25 const at::Tensor& flat,
26 at::TensorList tensors) {
27 return at::unflatten_dense_tensors(flat, tensors);
28 }
29
30 struct TensorGroup {
31 std::vector<at::Tensor> tensors;
32 size_t size = 0;
33
type_idTensorGroup34 size_t type_id() {
35 AT_ASSERT(!tensors.empty());
36 return ::torch::utils::type_id(tensors[0]);
37 }
38
optionsTensorGroup39 const at::TensorOptions options() {
40 AT_ASSERT(!tensors.empty());
41 return tensors[0].options();
42 }
43 };
44
45 // Helper function that takes a list of tensors and splits them into tensor
46 // groups by the size limit and outputs these tensor groups. If the input
47 // tensors are of different tensor types, they will be split into different
48 // groups as well.
49 //
50 // Two options for splitting are provided to the user:
51 //
52 // Imagine the size_limit is 256 and the list of input tensors are:
53 // tensor_a(fp16 - 128 bytes),
54 // tensor_b(fp32 - 256 bytes),
55 // tensor_c(fp16 - 128 bytes),
56 //
57 // when fine_grained == false:
58 // The function will read the list of tensors sequentially and accumulate
59 // enough tensors for each data type until the size_limit, therefore:
60 // it will output: {{tensor_a, tensor_c}, {tensor_b}}
61 //
62 // when fine_grained == true:
63 // The function will read the list of tensors sequentially and accumulate
64 // enough tensors for all data types until the size_limit, and then split
65 // the accumulated tensors into different groups by data types, therefore:
66 // it will output: {{tensor_a}, {tensor_b}, {tensor_c}}
TORCH_API std::vector<TensorGroup> take_tensors(
    at::TensorList tensors,
    size_t size_limit,
    bool fine_grained = false);

/// Reorders `tensors` in place to follow the relative order of `order`.
/// NOTE(review): matching presumably groups by tensor type (see type_id
/// above) — confirm against the .cpp implementation.
TORCH_API void reorder_tensors_like(
    std::vector<at::Tensor>& tensors,
    at::TensorList order);

/// Flattens a list of sparse tensors into a pair of flat tensors —
/// presumably (flat_indices, flat_values), mirroring the parameters of
/// unflatten_sparse_tensors below.
TORCH_API std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors);

/// Inverse of flatten_sparse_tensors: rebuilds one sparse tensor per entry
/// of `tensors` from the flattened indices and values.
TORCH_API std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors);
83
84 } // namespace torch::utils
85