xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/save_restore_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_
17 #define TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_
18 
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
25 
26 namespace tensorflow {
27 namespace dtensor {
28 
29 // Defines an Metadata entry when saving a Tensor.
30 struct SavingTensorMetadata {
31   // Tracks index from the original save op.
32   int64_t tensor_index;
33   // The global shape of the saving tensor.
34   std::vector<int64_t> shape;
35   // The layout of the saving tensor.
36   Layout layout;
37 
SavingTensorMetadataSavingTensorMetadata38   SavingTensorMetadata(int64_t index, std::vector<int64_t> global_shape,
39                        Layout tensor_layout)
40       : tensor_index(index),
41         shape(std::move(global_shape)),
42         layout(std::move(tensor_layout)) {}
43 };
44 
// Tracks a complete specification for a particular save op.
// Users build multiple save ops from these fields as follows:
//
// save_op[i] = tf.SaveV2(
//              prefix = new_prefixes[i],
//              tensor_indices = tensor_indices[i],
//              shape_and_slices = shape_and_slice_spec[i])
//
// All three vectors are indexed by save op and are expected to have the same
// length.
struct SaveOpSpecs {
  // Checkpoint prefix for each save op.
  std::vector<std::string> new_prefixes;
  // For each save op, the indices of the tensors it saves.
  std::vector<std::vector<int>> tensor_indices;
  // For each save op, the shape_and_slices entries parallel to tensor_indices.
  std::vector<std::vector<std::string>> shape_and_slice_spec;

  // Takes each vector by value and moves it into the corresponding member.
  SaveOpSpecs(std::vector<std::string> prefixes,
              std::vector<std::vector<int>> indices,
              std::vector<std::vector<std::string>> specs)
      : new_prefixes(std::move(prefixes)),
        tensor_indices(std::move(indices)),
        shape_and_slice_spec(std::move(specs)) {}
};
65 
// Returns a device-specific suffix string for `device_id` out of
// `total_devices` devices, built with printf-style formatting.
// NOTE(review): presumably appended to a checkpoint prefix to make per-device
// file names unique — confirm against the implementation.
std::string DeviceSuffix(int device_id, int total_devices);
68 
69 // Builds a complete saving specification for each device on the mesh.
70 //
71 // The returned map contains a map of <device_id, SavingSpec>.
72 // Device_id is where the saving should happen, and SavingSpec is a
73 // mapping of <tensor_index -> shape_and_slices>. e.g.,
74 //
75 // A map of {device_id : 0 ->  {
76 //     0 : "2 0,1",
77 //     1 : ""
78 //   }
79 // }
80 //
81 // Means that device_0 is responsible for saving tensor 0 and 1 from the passed
82 // in tensors list. For tensor[0], it saves the only the first element in that
83 // 1d vector with 2 elements. For tensor[1], it saves all elements.
84 //
85 // We accept another map as input, that records the mapping of
86 // <tensor_index -> (tensor_global_shape, tensor_layout)>.
87 //
88 // (tensor_global_shape, tensor_layout & tensor_layout.mesh) defines which
89 // device saves what slices of the Tensor.
90 //
91 // For a complete definition of shape_and_slices field, please see:
92 // third_party/tensorflow/core/framework/tensor_slice.h
93 StatusOr<absl::flat_hash_map<
94     int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>>
95 BuildSavingSpec(absl::Span<const SavingTensorMetadata> tensor_metadatas);
96 
97 // For a given per device saving spec, find out the counts of SaveV2 ops
98 // needed and their corresponding inputs.
99 //
100 // Current SaveV2 op requires tensor_names to be unique in the list, which is a
101 // contract that distributed saving would break. For example, if the saving spec
102 // decides that device 0 is responsible for saving two slices of tensor[a], then
103 // a single SaveV2 op can't fufill. The setup is very likely to happen when
104 // saving on TPU - where 8 cores maps to 1 host. In that case, the CPU host will
105 // be responsible for saving slices on the same tensor across 8 TPU cores.
106 // TODO(b/179126981): Investigate whether we can make TF core API run with
107 // different slice spec on a same tensor key.
108 //
109 // That said, building one SaveV2 op for each save is wasteful, when a single
110 // SaveV2 op is capable of saving different tensors. Instead, we simply need to
111 // break the SaveV2 op to be able to track the longest saving specs for a single
112 // tensor happening on the device,  e.g.,
113 //
114 // For given saving specs:
115 //
116 // { 'tensor_name_a' : <"spec_a", "spec_a_2"> }
117 // { 'tensor_name_b' : <"spec_b"> }
118 //
119 // would result into two save ops, where:
120 //
121 // SaveOp1 (tensor_names = <"tensor_name_a", tensor_name_b">,
122 //                           slice_spec = <"spec_a", "spec_b">)
123 //
124 // SaveOp2 (tensor_names = "<tensor_name_a>", slice_spec = <"spec_a_2">.
125 //
126 // The output vectors tracks the new SaveV2 op parameters and they must agree on
127 // size and indexing for saving tensors.
128 //
129 // tensor_indices trackes a list of indices of tensors that are being saved for
130 // each Save op, e.g.,
131 //
132 // tensor_indices[0] is a list of tensors (in index form) that needs to be saved
133 // on the first SaveV2 op.
134 //
135 // shape_and_slice_specs tracks a list of shape_and_slice_specs being saved for
136 // each Save op, e.g.,
137 //
138 // shape_and_slice_spec[0] is a list of shape_and_slices parameters for SaveV2
139 // op.
140 SaveOpSpecs BuildPerDeviceSave(
141     const absl::flat_hash_map<int64_t, std::vector<std::string>>& saving_spec,
142     int device_id, absl::string_view prefix, int total_devices);
143 
144 // Figures out the tensor slice_spec for a given layout and mesh device
145 // location.
146 StatusOr<std::vector<std::string>> SliceSpecOnDevice(
147     const Layout& layout, const Mesh& mesh, const DeviceLocation& device_coords,
148     absl::Span<const int64_t> global_shape);
149 }  // namespace dtensor
150 
151 }  // namespace tensorflow
152 
153 #endif  // TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_
154