/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_
#define TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_

#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

// Defines a metadata entry for a tensor being saved.
struct SavingTensorMetadata {
  // Tracks the index of the tensor in the original save op.
  int64_t tensor_index;
  // The global shape of the tensor being saved.
  std::vector<int64_t> shape;
  // The layout of the tensor being saved.
  Layout layout;

  SavingTensorMetadata(int64_t index, std::vector<int64_t> global_shape,
                       Layout tensor_layout)
      : tensor_index(index),
        shape(std::move(global_shape)),
        layout(std::move(tensor_layout)) {}
};

// Tracks a complete specification for a particular save op.
// Users would build multiple save ops from the given fields in the following
// manner:
//
//   save_op[i] = tf.SaveV2(
//       prefix = new_prefixes[i],
//       tensor_indices = tensor_indices[i],
//       shape_and_slices = shape_and_slice_spec[i])
struct SaveOpSpecs {
  std::vector<std::string> new_prefixes;
  std::vector<std::vector<int>> tensor_indices;
  std::vector<std::vector<std::string>> shape_and_slice_spec;

  SaveOpSpecs(std::vector<std::string> prefixes,
              std::vector<std::vector<int>> indices,
              std::vector<std::vector<std::string>> specs)
      : new_prefixes(std::move(prefixes)),
        tensor_indices(std::move(indices)),
        shape_and_slice_spec(std::move(specs)) {}
};

// Returns a device suffix built with printf-style formatting from device_id
// and total_devices.
std::string DeviceSuffix(int device_id, int total_devices);

// Builds a complete saving specification for each device on the mesh.
//
// The returned map is a map of <device_id, SavingSpec>, where device_id is the
// device on which the saving should happen and SavingSpec is a mapping of
// <tensor_index -> shape_and_slices>, e.g.,
//
//   { device_id : 0 -> {
//       0 : "2 0,1",
//       1 : ""
//     }
//   }
//
// means that device 0 is responsible for saving tensors 0 and 1 from the
// passed-in tensor list. For tensor[0], it saves only the first element of
// that 1-D vector of 2 elements. For tensor[1], it saves all elements.
//
// The input records, for each tensor, a mapping of
// <tensor_index -> (tensor_global_shape, tensor_layout)>.
//
// (tensor_global_shape, tensor_layout and tensor_layout.mesh) determine which
// device saves which slices of the tensor.
//
// For a complete definition of the shape_and_slices field, please see:
// third_party/tensorflow/core/framework/tensor_slice.h
StatusOr<absl::flat_hash_map<
    int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>>
BuildSavingSpec(absl::Span<const SavingTensorMetadata> tensor_metadatas);
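
// Illustrative sketch (not part of the API): how a caller might build the
// per-tensor metadata and consume the resulting saving spec. The variable
// names and shapes below are hypothetical, and `layout_a` / `layout_b` are
// assumed to be valid Layouts on the same mesh.
//
//   std::vector<SavingTensorMetadata> metadata;
//   metadata.emplace_back(/*index=*/0, std::vector<int64_t>{2}, layout_a);
//   metadata.emplace_back(/*index=*/1, std::vector<int64_t>{4}, layout_b);
//
//   StatusOr<absl::flat_hash_map<
//       int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>>
//       saving_spec = BuildSavingSpec(metadata);
//   if (!saving_spec.ok()) return saving_spec.status();
//
//   // (*saving_spec)[d] maps each tensor index that device `d` should save
//   // to the shape_and_slices strings to feed into that device's SaveV2
//   // op(s).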

// For a given per-device saving spec, computes how many SaveV2 ops are needed
// and their corresponding inputs.
//
// The current SaveV2 op requires tensor_names to be unique in its input list,
// which is a contract that distributed saving would break. For example, if the
// saving spec decides that device 0 is responsible for saving two slices of
// tensor[a], then a single SaveV2 op can't fulfill the request. This setup is
// very likely to happen when saving on TPU, where 8 cores map to 1 host: the
// CPU host becomes responsible for saving slices of the same tensor from all
// 8 TPU cores.
// TODO(b/179126981): Investigate whether we can make the TF core API run with
// different slice specs on the same tensor key.
//
// That said, building one SaveV2 op for each save is wasteful, since a single
// SaveV2 op is capable of saving different tensors. Instead, we only need to
// split into as many SaveV2 ops as the largest number of slice specs saved for
// any single tensor on the device, e.g., for the given saving specs:
//
//   { 'tensor_name_a' : <"spec_a", "spec_a_2"> }
//   { 'tensor_name_b' : <"spec_b"> }
//
// the result is two save ops, where:
//
//   SaveOp1(tensor_names = <"tensor_name_a", "tensor_name_b">,
//           slice_spec = <"spec_a", "spec_b">)
//
//   SaveOp2(tensor_names = <"tensor_name_a">, slice_spec = <"spec_a_2">)
//
// The output vectors track the new SaveV2 op parameters and must agree on size
// and indexing of the saved tensors.
//
// tensor_indices tracks the list of tensor indices being saved by each save
// op, e.g.,
//
//   tensor_indices[0] is the list of tensors (as indices) that the first
//   SaveV2 op needs to save.
//
// shape_and_slice_spec tracks the list of shape_and_slices strings being saved
// by each save op, e.g.,
//
//   shape_and_slice_spec[0] is the list of shape_and_slices parameters for the
//   first SaveV2 op.
SaveOpSpecs BuildPerDeviceSave(
    const absl::flat_hash_map<int64_t, std::vector<std::string>>& saving_spec,
    int device_id, absl::string_view prefix, int total_devices);

// Figures out the tensor slice_spec for a given layout and mesh device
// location.
StatusOr<std::vector<std::string>> SliceSpecOnDevice(
    const Layout& layout, const Mesh& mesh,
    const DeviceLocation& device_coords,
    absl::Span<const int64_t> global_shape);

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_SAVE_RESTORE_UTIL_H_
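
// Illustrative end-to-end sketch (comment only, not part of the API): emitting
// the SaveV2 parameters for one device from a `saving_spec` returned by
// BuildSavingSpec above. `saving_spec`, `device_id`, `prefix`, and
// `total_devices` are hypothetical local values.
//
//   const auto& per_device_spec = (*saving_spec)[device_id];
//   SaveOpSpecs specs =
//       BuildPerDeviceSave(per_device_spec, device_id, prefix, total_devices);
//
//   // specs.new_prefixes[i], specs.tensor_indices[i], and
//   // specs.shape_and_slice_spec[i] together parameterize the i-th SaveV2 op
//   // emitted for this device, as described in the SaveOpSpecs comment.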