/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/save_restore_util.h"

namespace tensorflow {
namespace dtensor {

namespace {
// A map keyed by the index into save_v2.tensor_names.
// For example, {2 : <"spec_a", "spec_b">} means that
// save_v2.tensor_names[2] should have "spec_a" and "spec_b" saved.
using SliceSpecByName = absl::flat_hash_map<int64_t, std::vector<std::string>>;

// Builds a map from saving device_id to the tensor slice specs that the device
// needs to save, for the given Tensor and layout.
//
// For a sharded Tensor, each device holds a slice of the Tensor, but that
// slice is not necessarily unique to the device. For a Tensor sharded 2 ways
// on its first dimension over a (2,4) mesh, devices 0-3 all hold the first
// slice and devices 4-7 all hold the second. To avoid saving duplicate copies
// of a slice, the map only records the min(device_id) that holds each unique
// slice, and the save happens from there.
//
// Furthermore, saving a Tensor that isn't on a CPU mesh requires a send/recv
// from the saving device to its corresponding host (CPU) device. Since
// multi-mesh execution isn't available yet, that case is not implemented.
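//
// For instance (assuming row-major device numbering, so devices 0-3 sit at
// coordinate 0 of the first mesh dimension and devices 4-7 at coordinate 1),
// a Tensor with global shape [4] sharded over that dimension produces
//   {0: {"4 0,2"}, 4: {"4 2,2"}}
// i.e. device 0 saves the first half, device 4 saves the second half, and the
// remaining devices save nothing because they only hold duplicate slices.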
StatusOr<SliceSpecByName> BuildSliceSpecDeviceMap(
    absl::Span<const int64_t> global_shape, Layout layout) {
  if (!layout.mesh().is_cpu_mesh())
    return errors::Unimplemented(
        "Saving tensors on non CPU mesh needs explicit send/receive and isn't "
        "implemented yet");

  // Result map recording, per slice spec, the minimum device_id that holds the
  // unique copy.
  // Note that llvm::SmallDenseMap won't accept std::string as a key.
  absl::flat_hash_map<std::string, int64_t> min_device_for_slice_spec;
  // Map from device_id to the list of slice_specs it needs to save.
  SliceSpecByName device_slices;

  const auto& mesh = layout.mesh();
  // Construct the SliceSpec for each device in the mesh.
  for (int device_id = 0; device_id < mesh.size(); ++device_id) {
    TF_ASSIGN_OR_RETURN(const DeviceLocation& coords,
                        mesh.device_location(device_id));
    // Compute the per-dimension slice specs for this device.
    TF_ASSIGN_OR_RETURN(std::vector<std::string> slice_specs,
                        SliceSpecOnDevice(layout, mesh, coords, global_shape));

    // Build the full slice_spec from the per-dimension pieces.
    std::string slice_spec = absl::StrJoin(slice_specs, ":");
    // The shape part of shape_and_slice is always the full (global) shape.
    std::string shape_spec = absl::StrJoin(global_shape, " ");
    // Concat shape spec and slice spec to form a complete shape_and_slice.
    std::string shape_and_slice = absl::StrCat(shape_spec, " ", slice_spec);

    // Only record the min device_id for each unique slice_spec of a given
    // Tensor.
    if (min_device_for_slice_spec.find(shape_and_slice) ==
            min_device_for_slice_spec.end() ||
        device_id < min_device_for_slice_spec[shape_and_slice]) {
      min_device_for_slice_spec[shape_and_slice] = device_id;
    }
  }

  // Construct the device_id keyed map for future save operations conditioned
  // on device_ids.
  for (const auto& spec_and_id : min_device_for_slice_spec) {
    device_slices[spec_and_id.second].push_back(spec_and_id.first);
  }

  return device_slices;
}

}  // namespace

// Example is _dev-02-of-16.
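// The device index is zero-padded to the number of digits in total_devices
// (via absl::StrCat(total_devices).size()), so e.g. DeviceSuffix(2, 16)
// returns "_dev-02-of-16".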
std::string DeviceSuffix(int device_id, int total_devices) {
  return absl::StrFormat("_dev-%0*d-of-%d", absl::StrCat(total_devices).size(),
                         device_id, total_devices);
}

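// Builds the global saving spec: a map keyed by saving device_id whose value
// maps each tensor index to the shape_and_slice strings that device should
// save. Fully replicated tensors are saved once from device 0 with an empty
// slice spec; sharded tensors are spread over the minimal set of devices
// computed by BuildSliceSpecDeviceMap above.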
StatusOr<absl::flat_hash_map<
    int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>>
BuildSavingSpec(absl::Span<const SavingTensorMetadata> tensor_metadatas) {
  absl::flat_hash_map<int64_t,
                      absl::flat_hash_map<int64_t, std::vector<std::string>>>
      saving_specs;
  for (const SavingTensorMetadata& tensor_metadata : tensor_metadatas) {
    // We use the index to select the tensor names and shape_and_slices from
    // the inputs. This is generic regardless of whether the inputs are
    // constants or just arguments.
    int index = tensor_metadata.tensor_index;
    const Layout& layout = tensor_metadata.layout;
    absl::Span<const int64_t> tensor_shape = tensor_metadata.shape;

    if (layout.IsFullyReplicated()) {
      // Push a fully replicated save on device 0, where the slice_spec is
      // simply an empty string.
      saving_specs[0][index].push_back("");
    } else {
      // Calculate shape_and_slices for the sharded case here.
      TF_ASSIGN_OR_RETURN(const auto& slice_specs,
                          BuildSliceSpecDeviceMap(tensor_shape, layout));
      // Push the specs for each device into the global map.
      for (const auto& slice_spec : slice_specs) {
        int64_t saving_device_id = slice_spec.first;
        for (const std::string& slice : slice_spec.second) {
          saving_specs[saving_device_id][index].push_back(slice);
        }
      }
    }
  }

  return saving_specs;
}

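// Splits one device's saving spec into one or more SaveV2 ops. Each tensor
// contributes its k-th slice spec to the k-th save op, so no single SaveV2 op
// ever sees a duplicated tensor name. For example, a spec such as
//   {3: {"spec_a", "spec_b"}, 5: {"spec_c"}}
// yields two save ops: the first writes tensors 3 and 5 with specs "spec_a"
// and "spec_c", the second writes tensor 3 again with spec "spec_b".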
SaveOpSpecs BuildPerDeviceSave(
    const absl::flat_hash_map<int64_t, std::vector<std::string>>& saving_spec,
    const int device_id, absl::string_view prefix, const int total_devices) {
  std::vector<std::string> new_prefixes;
  std::vector<std::vector<int>> tensor_indices;
  std::vector<std::vector<std::string>> shape_and_slice_specs;
  for (const auto& tensor_name_index_and_slice_specs : saving_spec) {
    int tensor_index = tensor_name_index_and_slice_specs.first;
    const std::vector<std::string>& specs =
        tensor_name_index_and_slice_specs.second;
    // For each tensor_name, we save its first slice_spec in the first
    // save op, its second slice_spec in the second save op, etc.
    // This allows us to group save ops together without running into
    // duplicated tensor_names (which the SaveV2 op doesn't support).
    for (int save_op_index = 0; save_op_index < specs.size(); ++save_op_index) {
      if (save_op_index >= tensor_indices.size()) {
        tensor_indices.push_back({});
        shape_and_slice_specs.push_back({});
        // Generate a new prefix based on device_id, only when we need a new
        // save op.
        new_prefixes.push_back(
            absl::StrCat(prefix, DeviceSuffix(device_id, total_devices)));
      }
      tensor_indices[save_op_index].push_back(tensor_index);
      shape_and_slice_specs[save_op_index].push_back(specs[save_op_index]);
    }
  }

  return SaveOpSpecs(new_prefixes, tensor_indices, shape_and_slice_specs);
}

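// Returns one slice spec string per tensor dimension for the given device:
// "-" for a replicated dimension and "start,size" for a sharded one. For
// example, a tensor with global shape [4, 8] that is sharded 2 ways on
// dimension 0 yields {"2,2", "-"} on a device at coordinate 1 of the
// corresponding mesh dimension.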
StatusOr<std::vector<std::string>> SliceSpecOnDevice(
    const Layout& layout, const Mesh& mesh, const DeviceLocation& device_coords,
    absl::Span<const int64_t> global_shape) {
  // Prefill every dimension with the replicated slice spec ("-").
  std::vector<std::string> slice_specs(global_shape.size(), "-");

  const std::vector<std::string>& sharding_spec_strs =
      layout.sharding_spec_strs();
  for (int tensor_dim_index = 0; tensor_dim_index < sharding_spec_strs.size();
       ++tensor_dim_index) {
    const std::string& mesh_dim = sharding_spec_strs[tensor_dim_index];
    if (layout.IsShardedDimension(mesh_dim)) {
      TF_ASSIGN_OR_RETURN(int mesh_dim_index, mesh.idx_for_dim(mesh_dim));
      TF_ASSIGN_OR_RETURN(int64_t dim_size, mesh.dim_size(mesh_dim));
      int64_t per_slice_size = global_shape[tensor_dim_index] / dim_size;
      int start = device_coords[mesh_dim_index] * per_slice_size;
      slice_specs[tensor_dim_index] = absl::StrCat(start, ",", per_slice_size);
    }
  }
  return slice_specs;
}

}  // namespace dtensor
}  // namespace tensorflow